diff --git a/eng/common/tsp-client/package-lock.json b/eng/common/tsp-client/package-lock.json index f0d729273161..d3a911a3bfe0 100644 --- a/eng/common/tsp-client/package-lock.json +++ b/eng/common/tsp-client/package-lock.json @@ -214,7 +214,6 @@ "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.29.0.tgz", "integrity": "sha512-9NhCeYjq9+3uxgdtp20LSiJXJvN0FeCtNGpJxuMFZ1Kv3cWUNb6DOhJwUvcVCzKGR66cw4njwM6hrJLqgOwbcw==", "license": "MIT", - "peer": true, "dependencies": { "@babel/helper-validator-identifier": "^7.28.5", "js-tokens": "^4.0.0", @@ -229,7 +228,6 @@ "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", "license": "MIT", - "peer": true, "engines": { "node": ">=6.9.0" } @@ -239,7 +237,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/ansi/-/ansi-2.0.5.tgz", "integrity": "sha512-doc2sWgJpbFQ64UflSVd17ibMGDuxO1yKgOgLMwavzESnXjFWJqUeG8saYosqKpHp4kWiM5x1nXvEjbpx90gzw==", "license": "MIT", - "peer": true, "engines": { "node": ">=23.5.0 || ^22.13.0 || ^21.7.0 || ^20.12.0" } @@ -249,7 +246,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/checkbox/-/checkbox-5.1.3.tgz", "integrity": "sha512-+G7I8CT+EHv/hasNfUl3P37DVoMoZfpA+2FXmM54dA8MxYle1YqucxbacxHalw1iAFSdKNEDTGNV7F+j1Ldqcg==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/ansi": "^2.0.5", "@inquirer/core": "^11.1.8", @@ -273,7 +269,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/confirm/-/confirm-6.0.11.tgz", "integrity": "sha512-pTpHjg0iEIRMYV/7oCZUMf27/383E6Wyhfc/MY+AVQGEoUobffIYWOK9YLP2XFRGz/9i6WlTQh1CkFVIo2Y7XA==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/core": "^11.1.8", "@inquirer/type": "^4.0.5" @@ -295,7 +290,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/core/-/core-11.1.8.tgz", "integrity": 
"sha512-/u+yJk2pOKNDOh1ZgdUH2RQaRx6OOH4I0uwL95qPvTFTIL38YBsuSC4r1yXBB3Q6JvNqFFc202gk0Ew79rrcjA==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/ansi": "^2.0.5", "@inquirer/figures": "^2.0.5", @@ -322,7 +316,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/editor/-/editor-5.1.0.tgz", "integrity": "sha512-6wlkYl65Qfayy48gPCfU4D7li6KCAGN79mLXa/tYHZH99OfZ820yY+HA+DgE88r8YwwgeuY6PQgNqMeK6LuMmw==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/core": "^11.1.8", "@inquirer/external-editor": "^3.0.0", @@ -345,7 +338,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/expand/-/expand-5.0.12.tgz", "integrity": "sha512-vOfrB33b7YIZfDauXS8vNNz2Z86FozTZLIt7e+7/dCaPJ1RXZsHCuI9TlcERzEUq57vkM+UdnBgxP0rFd23JYQ==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/core": "^11.1.8", "@inquirer/type": "^4.0.5" @@ -367,7 +359,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/external-editor/-/external-editor-3.0.0.tgz", "integrity": "sha512-lDSwMgg+M5rq6JKBYaJwSX6T9e/HK2qqZ1oxmOwn4AQoJE5D+7TumsxLGC02PWS//rkIVqbZv3XA3ejsc9FYvg==", "license": "MIT", - "peer": true, "dependencies": { "chardet": "^2.1.1", "iconv-lite": "^0.7.2" @@ -389,7 +380,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/figures/-/figures-2.0.5.tgz", "integrity": "sha512-NsSs4kzfm12lNetHwAn3GEuH317IzpwrMCbOuMIVytpjnJ90YYHNwdRgYGuKmVxwuIqSgqk3M5qqQt1cDk0tGQ==", "license": "MIT", - "peer": true, "engines": { "node": ">=23.5.0 || ^22.13.0 || ^21.7.0 || ^20.12.0" } @@ -399,7 +389,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/input/-/input-5.0.11.tgz", "integrity": "sha512-twUWidn4ocPO8qi6fRM7tNWt7W1FOnOZqQ+/+PsfLUacMR5rFLDPK9ql0nBPwxi0oELbo8T5NhRs8B2+qQEqFQ==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/core": "^11.1.8", "@inquirer/type": "^4.0.5" @@ -421,7 +410,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/number/-/number-4.0.11.tgz", "integrity": 
"sha512-Vscmim9TCksQsfjPtka/JwPUcbLhqWYrgfPf1cHrCm24X/F2joFwnageD50yMKsaX14oNGOyKf/RNXAFkNjWpA==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/core": "^11.1.8", "@inquirer/type": "^4.0.5" @@ -443,7 +431,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/password/-/password-5.0.11.tgz", "integrity": "sha512-9KZFeRaNHIcejtPb0wN4ddFc7EvobVoAFa049eS3LrDZFxI8O7xUXiITEOinBzkZFAIwY5V4yzQae/QfO9cbbg==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/ansi": "^2.0.5", "@inquirer/core": "^11.1.8", @@ -466,7 +453,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/prompts/-/prompts-8.4.1.tgz", "integrity": "sha512-AH5xPQ997K7e0F0vulPlteIHke2awMkFi8F0dBemrDfmvtPmHJo82mdHbONC4F/t8d1NHwrbI5cGVI+RbLWdoQ==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/checkbox": "^5.1.3", "@inquirer/confirm": "^6.0.11", @@ -496,7 +482,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/rawlist/-/rawlist-5.2.7.tgz", "integrity": "sha512-AqRMiD9+uE1lskDPrdqHwrV/EUmxKEBLX44SR7uxK3vD2413AmVfE5EQaPeNzYf5Pq5SitHJDYUFVF0poIr09w==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/core": "^11.1.8", "@inquirer/type": "^4.0.5" @@ -518,7 +503,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/search/-/search-4.1.7.tgz", "integrity": "sha512-1y7+0N65AWk5RdlXH/Kn13txf3IjIQ7OEfhCEkDTU+h5wKMLq8DUF3P6z+/kLSxDGDtQT1dRBWEUC3o/VvImsQ==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/core": "^11.1.8", "@inquirer/figures": "^2.0.5", @@ -541,7 +525,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/select/-/select-5.1.3.tgz", "integrity": "sha512-zYyqWgGQi3NhBcNq4Isc5rB3oEdQEh1Q/EcAnOW0FK4MpnXWkvSBYgA4cYrTM4A9UB573omouZbnL9JJ74Mq3A==", "license": "MIT", - "peer": true, "dependencies": { "@inquirer/ansi": "^2.0.5", "@inquirer/core": "^11.1.8", @@ -565,7 +548,6 @@ "resolved": "https://registry.npmjs.org/@inquirer/type/-/type-4.0.5.tgz", "integrity": 
"sha512-aetVUNeKNc/VriqXlw1NRSW0zhMBB0W4bNbWRJgzRl/3d0QNDQFfk0GO5SDdtjMZVg6o8ZKEiadd7SCCzoOn5Q==", "license": "MIT", - "peer": true, "engines": { "node": ">=23.5.0 || ^22.13.0 || ^21.7.0 || ^20.12.0" }, @@ -583,7 +565,6 @@ "resolved": "https://registry.npmjs.org/@isaacs/fs-minipass/-/fs-minipass-4.0.1.tgz", "integrity": "sha512-wgm9Ehl2jpeqP3zw/7mo3kRHFp5MEDhqAdwy1fTGkHAwnkGOVsgpvQhL8B5n1qlb01jV3n/bI0ZfZp5lWA1k4w==", "license": "ISC", - "peer": true, "dependencies": { "minipass": "^7.0.4" }, @@ -611,7 +592,6 @@ "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", "license": "MIT", - "peer": true, "dependencies": { "@nodelib/fs.stat": "2.0.5", "run-parallel": "^1.1.9" @@ -625,7 +605,6 @@ "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", "license": "MIT", - "peer": true, "engines": { "node": ">= 8" } @@ -635,7 +614,6 @@ "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", "license": "MIT", - "peer": true, "dependencies": { "@nodelib/fs.scandir": "2.1.5", "fastq": "^1.6.0" @@ -664,7 +642,6 @@ "resolved": "https://registry.npmjs.org/@sindresorhus/merge-streams/-/merge-streams-4.0.0.tgz", "integrity": "sha512-tlqY9xq5ukxTUZBmoOp+m61cqwQD5pHJtFY3Mn8CA8ps6yghLH/Hw8UPdqg4OLmFW3IFlcXnQNmo/dh8HzXYIQ==", "license": "MIT", - "peer": true, "engines": { "node": ">=18" }, @@ -725,7 +702,6 @@ "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.2.2.tgz", "integrity": "sha512-Bq3SmSpyFHaWjPk8If9yc6svM8c56dB5BAtW4Qbw5jHTwwXXcTLoRMkpDJp6VL0XzlWaCHTXrkFURMYmD0sLqg==", "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -738,7 +714,6 @@ 
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.3.tgz", "integrity": "sha512-4Dj6M28JB+oAH8kFkTLUo+a2jwOFkuqb3yucU0CANcRRUbxS0cP0nZYCGjcc3BNXwRIsUVmDGgzawme7zvJHvg==", "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -751,7 +726,6 @@ "resolved": "https://registry.npmjs.org/cliui/-/cliui-9.0.1.tgz", "integrity": "sha512-k7ndgKhwoQveBL+/1tqGJYNz097I7WOvwbmmU2AR5+magtbjPWQTS1C5vzGkBC8Ym8UWRzfKUzUUqFLypY4Q+w==", "license": "ISC", - "peer": true, "dependencies": { "string-width": "^7.2.0", "strip-ansi": "^7.1.0", @@ -765,15 +739,13 @@ "version": "10.6.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-10.6.0.tgz", "integrity": "sha512-toUI84YS5YmxW219erniWD0CIVOo46xGKColeNQRgOzDorgBi1v4D71/OFzgD9GO2UGKIv1C3Sp8DAn0+j5w7A==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/@typespec/compiler/node_modules/string-width": { "version": "7.2.0", "resolved": "https://registry.npmjs.org/string-width/-/string-width-7.2.0.tgz", "integrity": "sha512-tsaTIkKW9b4N+AEj+SVA+WhJzV7/zMhcSu78mLKWSk7cXMOSHsBKFWUs0fWwq8QyK3MgJBQRX6Gbi4kYbdvGkQ==", "license": "MIT", - "peer": true, "dependencies": { "emoji-regex": "^10.3.0", "get-east-asian-width": "^1.0.0", @@ -791,7 +763,6 @@ "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.2.0.tgz", "integrity": "sha512-yDPMNjp4WyfYBkHnjIRLfca1i6KMyGCtsVgoKe/z1+6vukgaENdgGBZt+ZmKPc4gavvEZ5OgHfHdrazhgNyG7w==", "license": "MIT", - "peer": true, "dependencies": { "ansi-regex": "^6.2.2" }, @@ -807,7 +778,6 @@ "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-9.0.2.tgz", "integrity": "sha512-42AtmgqjV+X1VpdOfyTGOYRi0/zsoLqtXQckTmqTeybT+BDIbM/Guxo7x3pE2vtpr1ok6xRqM9OpBe+Jyoqyww==", "license": "MIT", - "peer": true, "dependencies": { "ansi-styles": "^6.2.1", "string-width": "^7.0.0", @@ -825,7 +795,6 @@ "resolved": "https://registry.npmjs.org/yargs/-/yargs-18.0.0.tgz", "integrity": 
"sha512-4UEqdc2RYGHZc7Doyqkrqiln3p9X2DZVxaGbwhn2pi7MrRagKaOcIKe8L3OxYcbhXLgLFUS3zAYuQjKBQgmuNg==", "license": "MIT", - "peer": true, "dependencies": { "cliui": "^9.0.1", "escalade": "^3.1.1", @@ -843,7 +812,6 @@ "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-22.0.0.tgz", "integrity": "sha512-rwu/ClNdSMpkSrUb+d6BRsSkLUq1fmfsY6TOpYzTwvwkg1/NRG85KBy3kq++A8LKQwX6lsu+aWad+2khvuXrqw==", "license": "ISC", - "peer": true, "engines": { "node": "^20.19.0 || ^22.12.0 || >=23" } @@ -991,7 +959,6 @@ "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz", "integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==", "license": "MIT", - "peer": true, "dependencies": { "fast-deep-equal": "^3.1.3", "fast-uri": "^3.0.1", @@ -1032,7 +999,6 @@ "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "license": "MIT", - "peer": true, "dependencies": { "fill-range": "^7.1.1" }, @@ -1056,22 +1022,19 @@ "version": "5.4.4", "resolved": "https://registry.npmjs.org/change-case/-/change-case-5.4.4.tgz", "integrity": "sha512-HRQyTk2/YPEkt9TnUPbOpr64Uw3KOicFWPVBb+xiHvd6eBx/qPr9xqfBFDT8P2vWsvvz4jbEkfDe71W3VyNu2w==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/chardet": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/chardet/-/chardet-2.1.1.tgz", "integrity": "sha512-PsezH1rqdV9VvyNhxxOW32/d75r01NY7TQCmOqomRo15ZSOKbpTFVsfjghxo6JloQUCGnH4k1LGu0R4yCLlWQQ==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/chownr": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/chownr/-/chownr-3.0.0.tgz", "integrity": "sha512-+IxzY9BZOQd/XuYPRmrvEVjF/nqj5kgT4kEq7VofrDoM1MxoRjEWkrCC3EtLi59TVawxTAn+orJwFQcrqEN1+g==", "license": "BlueOak-1.0.0", - "peer": true, "engines": { "node": ">=18" } @@ -1081,7 +1044,6 @@ "resolved": 
"https://registry.npmjs.org/cli-width/-/cli-width-4.1.0.tgz", "integrity": "sha512-ouuZd4/dm2Sw5Gmqy6bGyNNNe1qt9RpmxveLSO7KcgsTnU7RXfsw+/bukWGo1abgBiMAic068rclZsO4IWmmxQ==", "license": "ISC", - "peer": true, "engines": { "node": ">= 12" } @@ -1179,7 +1141,6 @@ "resolved": "https://registry.npmjs.org/env-paths/-/env-paths-4.0.0.tgz", "integrity": "sha512-pxP8eL2SwwaTRi/KHYwLYXinDs7gL3jxFcBYmEdYfZmZXbaVDvdppd0XBU8qVz03rDfKZMXg1omHCbsJjZrMsw==", "license": "MIT", - "peer": true, "dependencies": { "is-safe-filename": "^0.1.0" }, @@ -1203,15 +1164,13 @@ "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/fast-glob": { "version": "3.3.3", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", "license": "MIT", - "peer": true, "dependencies": { "@nodelib/fs.stat": "^2.0.2", "@nodelib/fs.walk": "^1.2.3", @@ -1227,15 +1186,13 @@ "version": "3.0.3", "resolved": "https://registry.npmjs.org/fast-string-truncated-width/-/fast-string-truncated-width-3.0.3.tgz", "integrity": "sha512-0jjjIEL6+0jag3l2XWWizO64/aZVtpiGE3t0Zgqxv0DPuxiMjvB3M24fCyhZUO4KomJQPj3LTSUnDP3GpdwC0g==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/fast-string-width": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/fast-string-width/-/fast-string-width-3.0.2.tgz", "integrity": "sha512-gX8LrtNEI5hq8DVUfRQMbr5lpaS4nMIWV+7XEbXk2b8kiQIizgnlr12B4dA3ZEx3308ze0O4Q1R+cHts8kyUJg==", "license": "MIT", - "peer": true, "dependencies": { "fast-string-truncated-width": "^3.0.2" } @@ -1254,15 +1211,13 @@ "url": "https://opencollective.com/fastify" } ], - "license": "BSD-3-Clause", - "peer": true + "license": "BSD-3-Clause" }, 
"node_modules/fast-wrap-ansi": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/fast-wrap-ansi/-/fast-wrap-ansi-0.2.0.tgz", "integrity": "sha512-rLV8JHxTyhVmFYhBJuMujcrHqOT2cnO5Zxj37qROj23CP39GXubJRBUFF0z8KFK77Uc0SukZUf7JZhsVEQ6n8w==", "license": "MIT", - "peer": true, "dependencies": { "fast-string-width": "^3.0.2" } @@ -1272,7 +1227,6 @@ "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.20.1.tgz", "integrity": "sha512-GGToxJ/w1x32s/D2EKND7kTil4n8OVk/9mycTc4VDza13lOvpUZTGX3mFSCtV9ksdGBVzvsyAVLM6mHFThxXxw==", "license": "ISC", - "peer": true, "dependencies": { "reusify": "^1.0.4" } @@ -1282,7 +1236,6 @@ "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "license": "MIT", - "peer": true, "dependencies": { "to-regex-range": "^5.0.1" }, @@ -1304,7 +1257,6 @@ "resolved": "https://registry.npmjs.org/get-east-asian-width/-/get-east-asian-width-1.5.0.tgz", "integrity": "sha512-CQ+bEO+Tva/qlmw24dCejulK5pMzVnUOFOijVogd3KQs07HnRIgp8TGipvCCRT06xeYEbpbgwaCxglFyiuIcmA==", "license": "MIT", - "peer": true, "engines": { "node": ">=18" }, @@ -1317,7 +1269,6 @@ "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", "license": "ISC", - "peer": true, "dependencies": { "is-glob": "^4.0.1" }, @@ -1330,7 +1281,6 @@ "resolved": "https://registry.npmjs.org/globby/-/globby-16.1.1.tgz", "integrity": "sha512-dW7vl+yiAJSp6aCekaVnVJxurRv7DCOLyXqEG3RYMYUg7AuJ2jCqPkZTA8ooqC2vtnkaMcV5WfFBMuEnTu1OQg==", "license": "MIT", - "peer": true, "dependencies": { "@sindresorhus/merge-streams": "^4.0.0", "fast-glob": "^3.3.3", @@ -1377,7 +1327,6 @@ "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.7.2.tgz", "integrity": 
"sha512-im9DjEDQ55s9fL4EYzOAv0yMqmMBSZp6G0VvFyTMPKWxiSBHUj9NW/qqLmXUwXrrM7AvqSlTCfvqRb0cM8yYqw==", "license": "MIT", - "peer": true, "dependencies": { "safer-buffer": ">= 2.1.2 < 3.0.0" }, @@ -1394,7 +1343,6 @@ "resolved": "https://registry.npmjs.org/ignore/-/ignore-7.0.5.tgz", "integrity": "sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==", "license": "MIT", - "peer": true, "engines": { "node": ">= 4" } @@ -1404,7 +1352,6 @@ "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", "license": "MIT", - "peer": true, "engines": { "node": ">=0.10.0" } @@ -1423,7 +1370,6 @@ "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", "license": "MIT", - "peer": true, "dependencies": { "is-extglob": "^2.1.1" }, @@ -1436,7 +1382,6 @@ "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", "license": "MIT", - "peer": true, "engines": { "node": ">=0.12.0" } @@ -1446,7 +1391,6 @@ "resolved": "https://registry.npmjs.org/is-path-inside/-/is-path-inside-4.0.0.tgz", "integrity": "sha512-lJJV/5dYS+RcL8uQdBDW9c9uWFLLBNRyFhnAKXw5tVqLlKZ4RMGZKv+YQ/IA3OhD+RpbJa1LLFM1FQPGyIXvOA==", "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -1459,7 +1403,6 @@ "resolved": "https://registry.npmjs.org/is-safe-filename/-/is-safe-filename-0.1.1.tgz", "integrity": "sha512-4SrR7AdnY11LHfDKTZY1u6Ga3RuxZdl3YKWWShO5iyuG5h8QS4GD2tOb04peBJ5I7pXbR+CGBNEhTcwK+FzN3g==", "license": "MIT", - "peer": true, "engines": { "node": ">=20" }, @@ -1472,7 +1415,6 @@ "resolved": "https://registry.npmjs.org/is-unicode-supported/-/is-unicode-supported-2.1.0.tgz", "integrity": 
"sha512-mE00Gnza5EEB3Ds0HfMyllZzbBrmLOX3vfWoj9A9PEnTfratQ/BcaJOuMhnkhjXvb2+FkY3VuHqtAGpTPmglFQ==", "license": "MIT", - "peer": true, "engines": { "node": ">=18" }, @@ -1484,22 +1426,19 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/json-schema-traverse": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", "license": "MIT", - "peer": true, "engines": { "node": ">= 8" } @@ -1509,7 +1448,6 @@ "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "license": "MIT", - "peer": true, "dependencies": { "braces": "^3.0.3", "picomatch": "^2.3.1" @@ -1523,7 +1461,6 @@ "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.3.tgz", "integrity": "sha512-tEBHqDnIoM/1rXME1zgka9g6Q2lcoCkxHLuc7ODJ5BxbP5d4c2Z5cGgtXAku59200Cx7diuHTOYfSBD8n6mm8A==", "license": "BlueOak-1.0.0", - "peer": true, "engines": { "node": ">=16 || 14 >=14.17" } @@ -1533,7 +1470,6 @@ "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.1.0.tgz", "integrity": "sha512-KZxYo1BUkWD2TVFLr0MQoM8vUUigWD3LlD83a/75BqC+4qE0Hb1Vo5v1FgcfaNXvfXzr+5EhQ6ing/CaBijTlw==", "license": "MIT", - "peer": true, "dependencies": { "minipass": "^7.1.2" }, @@ -1552,7 +1488,6 @@ "resolved": 
"https://registry.npmjs.org/mustache/-/mustache-4.2.0.tgz", "integrity": "sha512-71ippSywq5Yb7/tVYyGbkBggbU8H3u5Rz56fH60jGFgr8uHwxs+aSKeqmluIVzM0m0kB7xQjKS6qPfd0b2ZoqQ==", "license": "MIT", - "peer": true, "bin": { "mustache": "bin/mustache" } @@ -1562,7 +1497,6 @@ "resolved": "https://registry.npmjs.org/mute-stream/-/mute-stream-3.0.0.tgz", "integrity": "sha512-dkEJPVvun4FryqBmZ5KhDo0K9iDXAwn08tMLDinNdRBNPcYEDiWYysLcc6k3mjTMlbP9KyylvRpd4wFtwrT9rw==", "license": "ISC", - "peer": true, "engines": { "node": "^20.17.0 || >=22.9.0" } @@ -1571,15 +1505,13 @@ "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", - "license": "ISC", - "peer": true + "license": "ISC" }, "node_modules/picomatch": { "version": "2.3.2", "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz", "integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==", "license": "MIT", - "peer": true, "engines": { "node": ">=8.6" }, @@ -1592,7 +1524,6 @@ "resolved": "https://registry.npmjs.org/pluralize/-/pluralize-8.0.0.tgz", "integrity": "sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==", "license": "MIT", - "peer": true, "engines": { "node": ">=4" } @@ -1602,7 +1533,6 @@ "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.8.1.tgz", "integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==", "license": "MIT", - "peer": true, "bin": { "prettier": "bin/prettier.cjs" }, @@ -1640,8 +1570,7 @@ "url": "https://feross.org/support" } ], - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/require-directory": { "version": "2.1.1", @@ -1657,7 +1586,6 @@ "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", "integrity": 
"sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", "license": "MIT", - "peer": true, "engines": { "node": ">=0.10.0" } @@ -1667,7 +1595,6 @@ "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.1.0.tgz", "integrity": "sha512-g6QUff04oZpHs0eG5p83rFLhHeV00ug/Yf9nZM6fLeUrPguBTkTQOdpAWWspMh55TZfVQDPaN3NQJfbVRAxdIw==", "license": "MIT", - "peer": true, "engines": { "iojs": ">=1.0.0", "node": ">=0.10.0" @@ -1692,7 +1619,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "queue-microtask": "^1.2.2" } @@ -1701,15 +1627,13 @@ "version": "2.1.2", "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/semver": { "version": "7.7.4", "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.4.tgz", "integrity": "sha512-vFKC2IEtQnVhpT78h1Yp8wzwrf8CM+MzKMHGJZfBtzhZNycRFnXsHk6E5TxIkkMsgNS7mdX3AGB7x2QM2di4lA==", "license": "ISC", - "peer": true, "bin": { "semver": "bin/semver.js" }, @@ -1722,7 +1646,6 @@ "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", "license": "ISC", - "peer": true, "engines": { "node": ">=14" }, @@ -1752,7 +1675,6 @@ "resolved": "https://registry.npmjs.org/slash/-/slash-5.1.0.tgz", "integrity": "sha512-ZA6oR3T/pEyuqwMgAKT0/hAv8oAXckzbkmR0UkUosQ+Mc4RxGoJkRmwHgHufaenlyAgE1Mxgpdcrf75y6XcnDg==", "license": "MIT", - "peer": true, "engines": { "node": ">=14.16" }, @@ -1812,7 +1734,6 @@ "resolved": "https://registry.npmjs.org/tar/-/tar-7.5.13.tgz", "integrity": "sha512-tOG/7GyXpFevhXVh8jOPJrmtRpOTsYqUIkVdVooZYJS/z8WhfQUX8RJILmeuJNinGAMSu1veBr4asSHFt5/hng==", "license": "BlueOak-1.0.0", - "peer": true, "dependencies": { "@isaacs/fs-minipass": "^4.0.0", 
"chownr": "^3.0.0", @@ -1829,7 +1750,6 @@ "resolved": "https://registry.npmjs.org/temporal-polyfill/-/temporal-polyfill-0.3.2.tgz", "integrity": "sha512-TzHthD/heRK947GNiSu3Y5gSPpeUDH34+LESnfsq8bqpFhsB79HFBX8+Z834IVX68P3EUyRPZK5bL/1fh437Eg==", "license": "MIT", - "peer": true, "dependencies": { "temporal-spec": "0.3.1" } @@ -1838,15 +1758,13 @@ "version": "0.3.1", "resolved": "https://registry.npmjs.org/temporal-spec/-/temporal-spec-0.3.1.tgz", "integrity": "sha512-B4TUhezh9knfSIMwt7RVggApDRJZo73uZdj8AacL2mZ8RP5KtLianh2MXxL06GN9ESYiIsiuoLQhgVfwe55Yhw==", - "license": "ISC", - "peer": true + "license": "ISC" }, "node_modules/to-regex-range": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", "license": "MIT", - "peer": true, "dependencies": { "is-number": "^7.0.0" }, @@ -1865,7 +1783,6 @@ "resolved": "https://registry.npmjs.org/unicorn-magic/-/unicorn-magic-0.4.0.tgz", "integrity": "sha512-wH590V9VNgYH9g3lH9wWjTrUoKsjLF6sGLjhR4sH1LWpLmCOH0Zf7PukhDA8BiS7KHe4oPNkcTHqYkj7SOGUOw==", "license": "MIT", - "peer": true, "engines": { "node": ">=20" }, @@ -1878,7 +1795,6 @@ "resolved": "https://registry.npmjs.org/vscode-jsonrpc/-/vscode-jsonrpc-8.2.0.tgz", "integrity": "sha512-C+r0eKJUIfiDIfwJhria30+TYWPtuHJXHtI7J0YlOmKAo7ogxP20T0zxB7HZQIFhIyvoBPwWskjxrvAtfjyZfA==", "license": "MIT", - "peer": true, "engines": { "node": ">=14.0.0" } @@ -1888,7 +1804,6 @@ "resolved": "https://registry.npmjs.org/vscode-languageserver/-/vscode-languageserver-9.0.1.tgz", "integrity": "sha512-woByF3PDpkHFUreUa7Hos7+pUWdeWMXRd26+ZX2A8cFx6v/JPTtd4/uN0/jB6XQHYaOlHbio03NTHCqrgG5n7g==", "license": "MIT", - "peer": true, "dependencies": { "vscode-languageserver-protocol": "3.17.5" }, @@ -1901,7 +1816,6 @@ "resolved": "https://registry.npmjs.org/vscode-languageserver-protocol/-/vscode-languageserver-protocol-3.17.5.tgz", "integrity": 
"sha512-mb1bvRJN8SVznADSGWM9u/b07H7Ecg0I3OgXDuLdn307rl/J3A9YD6/eYOssqhecL27hK1IPZAsaqh00i/Jljg==", "license": "MIT", - "peer": true, "dependencies": { "vscode-jsonrpc": "8.2.0", "vscode-languageserver-types": "3.17.5" @@ -1911,15 +1825,13 @@ "version": "1.0.12", "resolved": "https://registry.npmjs.org/vscode-languageserver-textdocument/-/vscode-languageserver-textdocument-1.0.12.tgz", "integrity": "sha512-cxWNPesCnQCcMPeenjKKsOCKQZ/L6Tv19DTRIGuLWe32lyzWhihGVJ/rcckZXJxfdKCFvRLS3fpBIsV/ZGX4zA==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/vscode-languageserver-types": { "version": "3.17.5", "resolved": "https://registry.npmjs.org/vscode-languageserver-types/-/vscode-languageserver-types-3.17.5.tgz", "integrity": "sha512-Ld1VelNuX9pdF39h2Hgaeb5hEZM2Z3jUrrMgWQAu82jMtZp7p3vJT3BzToKtZI7NgQssZje5o0zryOrhQvzQAg==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/wrap-ansi": { "version": "7.0.0", @@ -1973,7 +1885,6 @@ "resolved": "https://registry.npmjs.org/yallist/-/yallist-5.0.0.tgz", "integrity": "sha512-YgvUTfwqyc7UXVMrB+SImsVYSmTS8X/tSrtdNZMImM+n7+QTriRXyXim0mBrTXNeqzVF0KWGgHPeiyViFFrNDw==", "license": "BlueOak-1.0.0", - "peer": true, "engines": { "node": ">=18" } diff --git a/sdk/ai/azure-ai-finetuning-sessions/CHANGELOG.md b/sdk/ai/azure-ai-finetuning-sessions/CHANGELOG.md new file mode 100644 index 000000000000..b957b2575b48 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/CHANGELOG.md @@ -0,0 +1,7 @@ +# Release History + +## 1.0.0b1 (1970-01-01) + +### Other Changes + + - Initial version \ No newline at end of file diff --git a/sdk/ai/azure-ai-finetuning-sessions/LICENSE b/sdk/ai/azure-ai-finetuning-sessions/LICENSE new file mode 100644 index 000000000000..63447fd8bbbf --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/LICENSE @@ -0,0 +1,21 @@ +Copyright (c) Microsoft Corporation. 
+ +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/sdk/ai/azure-ai-finetuning-sessions/MANIFEST.in b/sdk/ai/azure-ai-finetuning-sessions/MANIFEST.in new file mode 100644 index 000000000000..04c3af1c8fda --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/MANIFEST.in @@ -0,0 +1,7 @@ +include *.md +include LICENSE +include azure/ai/finetuning_sessions/py.typed +recursive-include tests *.py +recursive-include samples *.py *.md +include azure/__init__.py +include azure/ai/__init__.py diff --git a/sdk/ai/azure-ai-finetuning-sessions/README.md b/sdk/ai/azure-ai-finetuning-sessions/README.md new file mode 100644 index 000000000000..543764f269cf --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/README.md @@ -0,0 +1,78 @@ +# Azure Finetuning Sessions client library for Python + + +## Getting started + +### Install the package + +```bash +python -m pip install azure-ai-finetuning-sessions +``` + +#### Prerequisites + +- Python 3.9 or later is required to use this package. +- You need an [Azure subscription][azure_sub] to use this package. +- An existing Azure Finetuning Sessions instance. + +#### Create with an Azure Active Directory Credential +To use an [Azure Active Directory (AAD) token credential][authenticate_with_token], +provide an instance of the desired credential type obtained from the +[azure-identity][azure_identity_credentials] library. + +To authenticate with AAD, you must first [pip][pip] install [`azure-identity`][azure_identity_pip]. + +After setup, you can choose which type of [credential][azure_identity_credentials] from azure.identity to use. 
+As an example, [DefaultAzureCredential][default_azure_credential] can be used to authenticate the client: + +Set the values of the client ID, tenant ID, and client secret of the AAD application as environment variables: +`AZURE_CLIENT_ID`, `AZURE_TENANT_ID`, `AZURE_CLIENT_SECRET` + +Use the returned token credential to authenticate the client: + +```python +>>> from azure.ai.finetuning_sessions import FineTuningSessionClient +>>> from azure.identity import DefaultAzureCredential +>>> client = FineTuningSessionClient(endpoint='', credential=DefaultAzureCredential()) +``` + +## Examples + +```python +>>> from azure.ai.finetuning_sessions import FineTuningSessionClient +>>> from azure.identity import DefaultAzureCredential +>>> from azure.core.exceptions import HttpResponseError + +>>> client = FineTuningSessionClient(endpoint='', credential=DefaultAzureCredential()) +>>> try: +...     pass  # TODO: call a client operation here +... except HttpResponseError as e: +...     print('service responds error: {}'.format(e.response.json())) + +``` + +## Contributing + +This project welcomes contributions and suggestions. Most contributions require +you to agree to a Contributor License Agreement (CLA) declaring that you have +the right to, and actually do, grant us the rights to use your contribution. +For details, visit https://cla.microsoft.com. + +When you submit a pull request, a CLA-bot will automatically determine whether +you need to provide a CLA and decorate the PR appropriately (e.g., label, +comment). Simply follow the instructions provided by the bot. You will only +need to do this once across all repos using our CLA. + +This project has adopted the +[Microsoft Open Source Code of Conduct][code_of_conduct]. For more information, +see the Code of Conduct FAQ or contact opencode@microsoft.com with any +additional questions or comments. 
+ + +[code_of_conduct]: https://opensource.microsoft.com/codeofconduct/ +[authenticate_with_token]: https://docs.microsoft.com/azure/cognitive-services/authentication?tabs=powershell#authenticate-with-an-authentication-token +[azure_identity_credentials]: https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/identity/azure-identity#credentials +[azure_identity_pip]: https://pypi.org/project/azure-identity/ +[default_azure_credential]: https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/identity/azure-identity#defaultazurecredential +[pip]: https://pypi.org/project/pip/ +[azure_sub]: https://azure.microsoft.com/free/ diff --git a/sdk/ai/azure-ai-finetuning-sessions/_metadata.json b/sdk/ai/azure-ai-finetuning-sessions/_metadata.json new file mode 100644 index 000000000000..49515fdbafdf --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/_metadata.json @@ -0,0 +1,3 @@ +{ + "apiVersions": {} +} \ No newline at end of file diff --git a/sdk/ai/azure-ai-finetuning-sessions/apiview-properties.json b/sdk/ai/azure-ai-finetuning-sessions/apiview-properties.json new file mode 100644 index 000000000000..ddec01ea26af --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/apiview-properties.json @@ -0,0 +1,72 @@ +{ + "CrossLanguagePackageId": "Azure.AI.FineTuning.Sessions", + "CrossLanguageDefinitionId": { + "azure.ai.finetuning_sessions.models.AdamParams": "Azure.AI.FineTuning.Sessions.AdamParams", + "azure.ai.finetuning_sessions.models.ApiError": "OpenAI.Error", + "azure.ai.finetuning_sessions.models.ApiErrorResponse": "Azure.AI.Projects.ApiErrorResponse", + "azure.ai.finetuning_sessions.models.Checkpoint": "Azure.AI.FineTuning.Sessions.FineTuningCheckpoint", + "azure.ai.finetuning_sessions.models.CheckpointInfo": "Azure.AI.FineTuning.Sessions.FineTuningCheckpointInfo", + "azure.ai.finetuning_sessions.models.CheckpointList": "Azure.AI.FineTuning.Sessions.FineTuningCheckpointList", + "azure.ai.finetuning_sessions.models.CreateSessionRequest": 
"Azure.AI.FineTuning.Sessions.CreateFineTuningSessionRequest", + "azure.ai.finetuning_sessions.models.Cursor": "Azure.AI.FineTuning.Sessions.FineTuningCursor", + "azure.ai.finetuning_sessions.models.Datum": "Azure.AI.FineTuning.Sessions.FineTuningDatum", + "azure.ai.finetuning_sessions.models.ForwardBackwardInput": "Azure.AI.FineTuning.Sessions.ForwardBackwardInput", + "azure.ai.finetuning_sessions.models.OperationResult": "Azure.AI.FineTuning.Sessions.FineTuningOperationResult", + "azure.ai.finetuning_sessions.models.ForwardBackwardOperationResult": "Azure.AI.FineTuning.Sessions.ForwardBackwardOperationResult", + "azure.ai.finetuning_sessions.models.ForwardBackwardRequest": "Azure.AI.FineTuning.Sessions.ForwardBackwardRequest", + "azure.ai.finetuning_sessions.models.HeartbeatResponse": "Azure.AI.FineTuning.Sessions.FineTuningHeartbeatResponse", + "azure.ai.finetuning_sessions.models.LoRAConfig": "Azure.AI.FineTuning.Sessions.LoRAConfig", + "azure.ai.finetuning_sessions.models.LossFnConfig": "Azure.AI.FineTuning.Sessions.LossFnConfig", + "azure.ai.finetuning_sessions.models.LossFnInputs": "Azure.AI.FineTuning.Sessions.LossFnInputs", + "azure.ai.finetuning_sessions.models.ModelInput": "Azure.AI.FineTuning.Sessions.FineTuningModelInput", + "azure.ai.finetuning_sessions.models.ModelInputChunk": "Azure.AI.FineTuning.Sessions.FineTuningModelInputChunk", + "azure.ai.finetuning_sessions.models.OptimStepOperationResult": "Azure.AI.FineTuning.Sessions.OptimStepOperationResult", + "azure.ai.finetuning_sessions.models.OptimStepRequest": "Azure.AI.FineTuning.Sessions.OptimStepRequest", + "azure.ai.finetuning_sessions.models.SampledSequence": "Azure.AI.FineTuning.Sessions.FineTuningSampledSequence", + "azure.ai.finetuning_sessions.models.SampleOperationResult": "Azure.AI.FineTuning.Sessions.FineTuningSampleOperationResult", + "azure.ai.finetuning_sessions.models.SampleRequest": "Azure.AI.FineTuning.Sessions.FineTuningSampleRequest", + 
"azure.ai.finetuning_sessions.models.SamplingParams": "Azure.AI.FineTuning.Sessions.FineTuningSamplingParams", + "azure.ai.finetuning_sessions.models.SaveCheckpointOperationResult": "Azure.AI.FineTuning.Sessions.SaveCheckpointOperationResult", + "azure.ai.finetuning_sessions.models.SaveCheckpointRequest": "Azure.AI.FineTuning.Sessions.SaveCheckpointRequest", + "azure.ai.finetuning_sessions.models.SaveSamplerWeightsOperationResult": "Azure.AI.FineTuning.Sessions.SaveSamplerWeightsOperationResult", + "azure.ai.finetuning_sessions.models.SaveSamplerWeightsRequest": "Azure.AI.FineTuning.Sessions.SaveSamplerWeightsRequest", + "azure.ai.finetuning_sessions.models.Session": "Azure.AI.FineTuning.Sessions.FineTuningSession", + "azure.ai.finetuning_sessions.models.SessionList": "Azure.AI.FineTuning.Sessions.FineTuningSessionList", + "azure.ai.finetuning_sessions.models.SessionModelData": "Azure.AI.FineTuning.Sessions.FineTuningSessionModelData", + "azure.ai.finetuning_sessions.models.SessionSummary": "Azure.AI.FineTuning.Sessions.FineTuningSessionSummary", + "azure.ai.finetuning_sessions.models.TensorData": "Azure.AI.FineTuning.Sessions.TensorData", + "azure.ai.finetuning_sessions.models.OperationType": "Azure.AI.FineTuning.Sessions.FineTuningOperationType", + "azure.ai.finetuning_sessions.models.OperationStatus": "Azure.AI.FineTuning.Sessions.FineTuningOperationStatus", + "azure.ai.finetuning_sessions.models.FoundryFeaturesOptInKeys": "Azure.AI.Projects.FoundryFeaturesOptInKeys", + "azure.ai.finetuning_sessions.models.SessionType": "Azure.AI.FineTuning.Sessions.FineTuningSessionType", + "azure.ai.finetuning_sessions.models.SessionStatus": "Azure.AI.FineTuning.Sessions.FineTuningSessionStatus", + "azure.ai.finetuning_sessions.models.LossFn": "Azure.AI.FineTuning.Sessions.FineTuningLossFn", + "azure.ai.finetuning_sessions.models.CheckpointType": "Azure.AI.FineTuning.Sessions.FineTuningCheckpointType", + 
"azure.ai.finetuning_sessions.operations.SessionsOperations.begin_create": "Azure.AI.FineTuning.Sessions.FineTuningSessions.create", + "azure.ai.finetuning_sessions.aio.operations.SessionsOperations.begin_create": "Azure.AI.FineTuning.Sessions.FineTuningSessions.create", + "azure.ai.finetuning_sessions.operations.SessionsOperations.list": "Azure.AI.FineTuning.Sessions.FineTuningSessions.list", + "azure.ai.finetuning_sessions.aio.operations.SessionsOperations.list": "Azure.AI.FineTuning.Sessions.FineTuningSessions.list", + "azure.ai.finetuning_sessions.operations.SessionsOperations.get": "Azure.AI.FineTuning.Sessions.FineTuningSessions.get", + "azure.ai.finetuning_sessions.aio.operations.SessionsOperations.get": "Azure.AI.FineTuning.Sessions.FineTuningSessions.get", + "azure.ai.finetuning_sessions.operations.SessionsOperations.begin_unload": "Azure.AI.FineTuning.Sessions.FineTuningSessions.unload", + "azure.ai.finetuning_sessions.aio.operations.SessionsOperations.begin_unload": "Azure.AI.FineTuning.Sessions.FineTuningSessions.unload", + "azure.ai.finetuning_sessions.operations.SessionsOperations.heartbeat": "Azure.AI.FineTuning.Sessions.FineTuningSessions.heartbeat", + "azure.ai.finetuning_sessions.aio.operations.SessionsOperations.heartbeat": "Azure.AI.FineTuning.Sessions.FineTuningSessions.heartbeat", + "azure.ai.finetuning_sessions.operations.TrainingOperations.begin_forward_backward": "Azure.AI.FineTuning.Sessions.Training.forwardBackward", + "azure.ai.finetuning_sessions.aio.operations.TrainingOperations.begin_forward_backward": "Azure.AI.FineTuning.Sessions.Training.forwardBackward", + "azure.ai.finetuning_sessions.operations.TrainingOperations.begin_optim_step": "Azure.AI.FineTuning.Sessions.Training.optimStep", + "azure.ai.finetuning_sessions.aio.operations.TrainingOperations.begin_optim_step": "Azure.AI.FineTuning.Sessions.Training.optimStep", + "azure.ai.finetuning_sessions.operations.CheckpointsOperations.begin_save": 
"Azure.AI.FineTuning.Sessions.Checkpoints.save", + "azure.ai.finetuning_sessions.aio.operations.CheckpointsOperations.begin_save": "Azure.AI.FineTuning.Sessions.Checkpoints.save", + "azure.ai.finetuning_sessions.operations.CheckpointsOperations.begin_save_sampler_weights": "Azure.AI.FineTuning.Sessions.Checkpoints.saveSamplerWeights", + "azure.ai.finetuning_sessions.aio.operations.CheckpointsOperations.begin_save_sampler_weights": "Azure.AI.FineTuning.Sessions.Checkpoints.saveSamplerWeights", + "azure.ai.finetuning_sessions.operations.CheckpointsOperations.list": "Azure.AI.FineTuning.Sessions.Checkpoints.list", + "azure.ai.finetuning_sessions.aio.operations.CheckpointsOperations.list": "Azure.AI.FineTuning.Sessions.Checkpoints.list", + "azure.ai.finetuning_sessions.operations.CheckpointsOperations.get": "Azure.AI.FineTuning.Sessions.Checkpoints.get", + "azure.ai.finetuning_sessions.aio.operations.CheckpointsOperations.get": "Azure.AI.FineTuning.Sessions.Checkpoints.get", + "azure.ai.finetuning_sessions.operations.SamplingOperations.begin_sample": "Azure.AI.FineTuning.Sessions.Sampling.sample", + "azure.ai.finetuning_sessions.aio.operations.SamplingOperations.begin_sample": "Azure.AI.FineTuning.Sessions.Sampling.sample", + "azure.ai.finetuning_sessions.operations.Operations.get": "Azure.AI.FineTuning.Sessions.Operations.get", + "azure.ai.finetuning_sessions.aio.operations.Operations.get": "Azure.AI.FineTuning.Sessions.Operations.get" + } +} \ No newline at end of file diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/__init__.py b/sdk/ai/azure-ai-finetuning-sessions/azure/__init__.py new file mode 100644 index 000000000000..d55ccad1f573 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/__init__.py @@ -0,0 +1 @@ +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/__init__.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/__init__.py new file mode 100644 index 
000000000000..d55ccad1f573 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/__init__.py @@ -0,0 +1 @@ +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/__init__.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/__init__.py new file mode 100644 index 000000000000..b3670902c57f --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/__init__.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._client import FineTuningSessionClient # type: ignore +from ._exceptions import ( + BatchTooLargeError, + ContentionError, + EngineDeadError, + MalformedDatumError, + TrainingEngineError, + FineTuningSessionsError, + RequestValidationError, + NoCapacityError, +) +from ._logging_setup import install_default_logging as _install_default_logging +from ._version import VERSION + +__version__ = VERSION + +# Prepend a UTC timestamp to SDK warnings when no logging handler is +# configured to render one (i.e. records would fall through to +# logging.lastResort). No-op in any environment with a real handler. +# Idempotent; installs no handler. 
+_install_default_logging() + +try: + from ._patch import __all__ as _patch_all + from ._patch import * +except ImportError: + _patch_all = [] +from ._patch import patch_sdk as _patch_sdk + +__all__ = [ + "FineTuningSessionClient", + "FineTuningSession", + "FineTuningSessionsError", + "BatchTooLargeError", + "NoCapacityError", + "TrainingEngineError", + "EngineDeadError", + "ContentionError", + "RequestValidationError", + "MalformedDatumError", +] +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore + +_patch_sdk() diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_client.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_client.py new file mode 100644 index 000000000000..d820cfd7ac31 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_client.py @@ -0,0 +1,116 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
class FineTuningSessionClient:  # pylint: disable=client-accepts-api-version-keyword
    """Entry point that builds the HTTP pipeline and exposes the operation groups below.

    :ivar sessions: SessionsOperations operations
    :vartype sessions: azure.ai.finetuning_sessions.operations.SessionsOperations
    :ivar training: TrainingOperations operations
    :vartype training: azure.ai.finetuning_sessions.operations.TrainingOperations
    :ivar checkpoints: CheckpointsOperations operations
    :vartype checkpoints: azure.ai.finetuning_sessions.operations.CheckpointsOperations
    :ivar sampling: SamplingOperations operations
    :vartype sampling: azure.ai.finetuning_sessions.operations.SamplingOperations
    :ivar operations: Operations operations
    :vartype operations: azure.ai.finetuning_sessions.operations.Operations
    :param endpoint: Foundry Project endpoint in the form
     "https://{ai-services-account-name}.services.ai.azure.com/api/projects/{project-name}". If you
     only have one Project in your Foundry Hub, or to target the default Project in your Hub, use
     the form "https://{ai-services-account-name}.services.ai.azure.com/api/projects/_project".
     Required.
    :type endpoint: str
    :param credential: Credential used to authenticate requests to the service. Required.
    :type credential: ~azure.core.credentials.TokenCredential
    :keyword bool allow_insecure_http: Passed through to the client configuration, which uses it
     to pick the bearer-token policy variant that does not enforce HTTPS. Defaults to False.
    :keyword int polling_interval: Default waiting time between two polls for LRO operations if no
     Retry-After header is present.
    """

    def __init__(
        self, endpoint: str, credential: "TokenCredential", *, allow_insecure_http: bool = False, **kwargs: Any
    ) -> None:
        # Build the configuration first: it must receive the caller's kwargs before
        # "policies" is popped from them below.
        self._config = FineTuningSessionClientConfiguration(
            endpoint=endpoint, credential=credential, allow_insecure_http=allow_insecure_http, **kwargs
        )

        pipeline_policies = kwargs.pop("policies", None)
        if pipeline_policies is None:
            # Default pipeline; list order is the order in which the policies are chained.
            pipeline_policies = [
                policies.RequestIdPolicy(**kwargs),
                self._config.headers_policy,
                self._config.user_agent_policy,
                self._config.proxy_policy,
                policies.ContentDecodePolicy(**kwargs),
                self._config.redirect_policy,
                self._config.retry_policy,
                self._config.authentication_policy,
                self._config.custom_hook_policy,
                self._config.logging_policy,
                policies.DistributedTracingPolicy(**kwargs),
                policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None,
                self._config.http_logging_policy,
            ]
        # The base URL is a template; "{endpoint}" is substituted per request.
        self._client: PipelineClient = PipelineClient(base_url="{endpoint}", policies=pipeline_policies, **kwargs)

        self._serialize = Serializer()
        self._deserialize = Deserializer()
        self._serialize.client_side_validation = False

        # Every operation group shares the same pipeline client, config and (de)serializers.
        shared = (self._client, self._config, self._serialize, self._deserialize)
        self.sessions = SessionsOperations(*shared)
        self.training = TrainingOperations(*shared)
        self.checkpoints = CheckpointsOperations(*shared)
        self.sampling = SamplingOperations(*shared)
        self.operations = Operations(*shared)

    def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse:
        """Send a hand-built request through this client's policy pipeline.

        >>> from azure.core.rest import HttpRequest
        >>> request = HttpRequest("GET", "https://www.example.org/")
        <HttpRequest [GET], url: 'https://www.example.org/'>
        >>> response = client.send_request(request)
        <HttpResponse: 200 OK>

        For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request

        :param request: The network request you want to make. Required.
        :type request: ~azure.core.rest.HttpRequest
        :keyword bool stream: Whether the response payload will be streamed. Defaults to False.
        :return: The response of your network call. Does not do error handling on your response.
        :rtype: ~azure.core.rest.HttpResponse
        """
        # Work on a copy so the caller's request object is never mutated.
        outgoing = deepcopy(request)
        resolved_endpoint = self._serialize.url(
            "self._config.endpoint", self._config.endpoint, "str", skip_quote=True
        )
        outgoing.url = self._client.format_url(outgoing.url, endpoint=resolved_endpoint)
        return self._client.send_request(outgoing, stream=stream, **kwargs)  # type: ignore

    def close(self) -> None:
        """Close the underlying pipeline client."""
        self._client.close()

    def __enter__(self) -> Self:
        """Enter the underlying pipeline client's context and return this client."""
        self._client.__enter__()
        return self

    def __exit__(self, *exc_details: Any) -> None:
        """Exit the underlying pipeline client's context, forwarding the exception details."""
        self._client.__exit__(*exc_details)
class _InsecureBearerTokenCredentialPolicy(policies.BearerTokenCredentialPolicy):
    """Bearer-token policy variant that switches off HTTPS enforcement (local ``http://`` dev)."""

    def on_request(self, request: Any) -> None:
        # Set the per-request option first, then delegate to the base implementation.
        pipeline_options = request.context.options
        pipeline_options["enforce_https"] = False
        super().on_request(request)


class FineTuningSessionClientConfiguration:  # pylint: disable=too-many-instance-attributes
    """Configuration for FineTuningSessionClient.

    Every constructor argument is stored as an instance attribute.

    :param endpoint: Foundry Project endpoint in the form
     "https://{ai-services-account-name}.services.ai.azure.com/api/projects/{project-name}". If you
     only have one Project in your Foundry Hub, or to target the default Project in your Hub, use
     the form "https://{ai-services-account-name}.services.ai.azure.com/api/projects/_project".
     Required.
    :type endpoint: str
    :param credential: Credential used to authenticate requests to the service. An
     ``AzureKeyCredential`` instance is sent as an ``api-key`` header; anything else is treated
     as a token credential. Required.
    :type credential: ~azure.core.credentials.TokenCredential
    :keyword bool allow_insecure_http: When True, token authentication uses the policy variant
     that does not enforce HTTPS. Defaults to False.
    """

    def __init__(
        self, endpoint: str, credential: "TokenCredential", *, allow_insecure_http: bool = False, **kwargs: Any
    ) -> None:
        # Reject missing required arguments up front, endpoint first.
        for label, value in (("endpoint", endpoint), ("credential", credential)):
            if value is None:
                raise ValueError(f"Parameter '{label}' must not be None.")

        self.endpoint = endpoint
        self.credential = credential
        self.allow_insecure_http = allow_insecure_http
        self.api_version = kwargs.pop("api_version", "v1")
        self.credential_scopes = kwargs.pop("credential_scopes", ["https://ai.azure.com/.default"])
        kwargs.setdefault("sdk_moniker", f"finetuning-sessions/{VERSION}")
        self.polling_interval = kwargs.get("polling_interval", 30)
        self._configure(**kwargs)

    def _configure(self, **kwargs: Any) -> None:
        """Create each pipeline policy, preferring an instance supplied by the caller in kwargs."""
        self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs)
        self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs)
        self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs)
        self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs)
        self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs)
        self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs)
        self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs)
        self.retry_policy = kwargs.get("retry_policy") or policies.RetryPolicy(**kwargs)

        self.authentication_policy = kwargs.get("authentication_policy")
        if not self.credential or self.authentication_policy:
            # Nothing to derive: no credential, or the caller supplied a policy.
            return

        if isinstance(self.credential, AzureKeyCredential):
            # API key auth: sends an "api-key: <key>" header on every request.
            self.authentication_policy = AzureKeyCredentialPolicy(self.credential, name="api-key")
            return

        # Token (OAuth2) auth: HTTPS is enforced unless allow_insecure_http was requested.
        if self.allow_insecure_http:
            policy_cls = _InsecureBearerTokenCredentialPolicy
        else:
            policy_cls = policies.BearerTokenCredentialPolicy
        self.authentication_policy = policy_cls(self.credential, *self.credential_scopes, **kwargs)
class FineTuningSessionsError(HttpResponseError):
    """Root of the SDK's typed exception hierarchy.

    Subclasses ``azure.core.exceptions.HttpResponseError``, so handlers written as
    ``except HttpResponseError`` keep catching every typed error.
    """

    def __init__(self, message: str, *, response: Any = None, **kwargs: Any) -> None:
        super().__init__(message=message, response=response, **kwargs)


class BatchTooLargeError(FineTuningSessionsError):
    """Raised when the server rejects a batch for exceeding its size limit.

    Action: split the batch into smaller chunks and retry.

    Attributes:
        max_batch_size: Largest batch the server accepts, when it reported one.
        actual_batch_size: Size of the rejected batch, when it was reported.
    """

    def __init__(
        self,
        message: str,
        *,
        max_batch_size: Optional[int] = None,
        actual_batch_size: Optional[int] = None,
        response: Any = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, response=response, **kwargs)
        self.max_batch_size, self.actual_batch_size = max_batch_size, actual_batch_size


class NoCapacityError(FineTuningSessionsError):
    """Raised when no engine capacity is currently available.

    Action: wait ``retry_after_sec`` seconds and retry, or switch to a
    different project/endpoint.

    Attributes:
        retry_after_sec: Suggested number of seconds to wait before retrying.
        reason: Reason string reported by the server (e.g. ``"engine_busy"``).
    """

    def __init__(
        self,
        message: str,
        *,
        retry_after_sec: Optional[float] = None,
        reason: Optional[str] = None,
        response: Any = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, response=response, **kwargs)
        self.retry_after_sec, self.reason = retry_after_sec, reason


class TrainingEngineError(FineTuningSessionsError):
    """Raised when the engine serving the session has died.

    Action: stop the training run and alert on-call. LoRA weights in VRAM are
    lost. Do NOT retry on the same session -- create a new one (optionally
    from the last checkpoint).

    Attributes:
        session_id: Session that was being served.
        error_code: Server error code (e.g. ``"worker_crashed"``).
        debug_ref: Opaque reference to quote in support tickets.
    """

    def __init__(
        self,
        message: str,
        *,
        session_id: Optional[str] = None,
        error_code: Optional[str] = None,
        debug_ref: Optional[str] = None,
        response: Any = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, response=response, **kwargs)
        self.session_id, self.error_code, self.debug_ref = session_id, error_code, debug_ref


class ContentionError(FineTuningSessionsError):
    """Raised when the engine is temporarily contended (busy with other tenants).

    Action: back off with exponential delay. Do NOT retry immediately.

    Attributes:
        retry_after_sec: Suggested number of seconds to wait before retrying.
        reason: Reason string reported by the server.
    """

    def __init__(
        self,
        message: str,
        *,
        retry_after_sec: Optional[float] = None,
        reason: Optional[str] = None,
        response: Any = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, response=response, **kwargs)
        self.retry_after_sec, self.reason = retry_after_sec, reason


class RequestValidationError(FineTuningSessionsError):
    """Raised when one or more datums in the batch were rejected as invalid.

    Action: terminal for the affected datums -- fix the data.

    Attributes:
        field: Field that failed validation (e.g. ``"forward_backward_input.data"``).
        error_code: Server error code (e.g. ``"invalid_request"``).
        debug_ref: Opaque reference to quote in support tickets.
    """

    def __init__(
        self,
        message: str,
        *,
        field: Optional[str] = None,
        error_code: Optional[str] = None,
        debug_ref: Optional[str] = None,
        response: Any = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, response=response, **kwargs)
        self.field, self.error_code, self.debug_ref = field, error_code, debug_ref
def _coerce_retry_after(value: Any) -> Optional[float]:
    """Convert a server-supplied retry hint to float seconds; ``None`` if absent or unusable."""
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        # A malformed hint must not mask the HTTP error being classified.
        return None


def _classify_http_error(
    status_code: int,
    body: Optional[dict],
    *,
    response: Any = None,
    session_id: Optional[str] = None,
) -> Optional[FineTuningSessionsError]:
    """Attempt to classify an HTTP error response into a typed exception.

    :param status_code: HTTP status code of the failed response.
    :param body: Decoded JSON body of the response. ``None`` or any non-dict
        value is treated as an empty body.
    :keyword response: Response object attached to the returned exception.
    :keyword session_id: Session identifier attached to ``TrainingEngineError``.
    :return: A typed exception instance (not raised), or ``None`` if the error
        doesn't match any known pattern (caller should fall through to generic
        error handling).
    """
    if not isinstance(body, dict):
        # None, or a payload that is not a JSON object: there is nothing to inspect.
        body = {}

    # --- HTTP 413: Batch too large ---
    if status_code == 413:
        import re  # local: only needed on this path

        msg = body.get("message") or body.get("detail") or "Batch too large"
        field = body.get("field")
        # Try to extract the sizes from the message.
        max_size = None
        actual_size = None
        if isinstance(msg, str):
            # "Batch size (N) exceeds the maximum allowed (M)"
            m = re.search(r"Batch size \((\d+)\) exceeds the maximum allowed \((\d+)\)", msg)
            if m:
                actual_size = int(m.group(1))
                max_size = int(m.group(2))
        # A 413 that names a field other than the data payload (e.g. metadata) and
        # carries no batch-size numbers is not a batch error: let generic handling
        # take over. A missing field is still classified, since nothing says otherwise.
        names_other_field = isinstance(field, str) and bool(field) and "data" not in field
        if names_other_field and actual_size is None:
            return None
        return BatchTooLargeError(
            msg, max_batch_size=max_size, actual_batch_size=actual_size, response=response
        )

    # --- HTTP 503: No capacity / contention ---
    if status_code == 503:
        reason = body.get("reason", "")
        msg = body.get("message") or body.get("detail") or "No available engine capacity"
        retry_after = _coerce_retry_after(body.get("retry_after_sec"))

        # Legacy format: the only signal is the wording of the message/detail string.
        if isinstance(msg, str) and ("capacity" in msg.lower() or "no engine" in msg.lower()):
            return NoCapacityError(
                msg, retry_after_sec=retry_after, reason=reason or "engine_busy", response=response
            )
        if reason == "engine_busy":
            return NoCapacityError(
                msg, retry_after_sec=retry_after, reason=reason, response=response
            )
        # Generic 503 -- treat as contention
        return ContentionError(
            msg, retry_after_sec=retry_after, reason=reason or None, response=response
        )

    # --- HTTP 500: Engine dead / worker crashed / capacity exhaustion ---
    if status_code == 500:
        # The body may be a flat dict or nested under "detail".
        raw_detail = body.get("detail")
        effective = raw_detail if isinstance(raw_detail, dict) and raw_detail else body
        msg = (
            effective.get("error_message")
            or effective.get("message")
            or body.get("detail")
            or body.get("error")
            or "Internal server error"
        )
        error_type = effective.get("type", "")
        error_code = effective.get("error_code") or effective.get("code")
        debug_ref = effective.get("debug_ref")

        # Capacity exhaustion: server returns type="internal_model_error" with
        # a message mentioning capacity. Should really be a 503,
        # but older servers return 500.
        if isinstance(msg, str) and ("capacity" in msg.lower() or "no engine" in msg.lower()):
            return NoCapacityError(
                msg,
                retry_after_sec=None,
                reason=error_type or "no_capacity",
                response=response,
            )

        known_engine_code = error_code in ("worker_crashed", "engine_oom", "engine_timeout")
        # Message heuristics cover legacy plain-string responses without an error code.
        looks_like_engine_death = isinstance(msg, str) and any(
            kw in msg.lower() for kw in ("engine", "dead", "crashed", "died")
        )
        if known_engine_code or looks_like_engine_death:
            return TrainingEngineError(
                msg,
                session_id=session_id,
                error_code=error_code,
                debug_ref=debug_ref,
                response=response,
            )
        return None  # Unknown 500 -- don't classify

    # --- HTTP 400/422: Malformed datum / invalid request ---
    if status_code in (400, 422):
        error_type = body.get("type", "")
        msg = body.get("message") or body.get("detail") or "Invalid request"
        field = body.get("field")
        error_code = body.get("error_code") or body.get("code")
        debug_ref = body.get("debug_ref")

        if error_type == "validation_error" or error_code == "invalid_request":
            return RequestValidationError(
                msg,
                field=field,
                error_code=error_code or "invalid_request",
                debug_ref=debug_ref,
                response=response,
            )
        return None

    return None


def _classify_poll_failure(
    envelope: dict,
    *,
    session_id: Optional[str] = None,
) -> Optional[FineTuningSessionsError]:
    """Classify a failed poll-endpoint envelope into a typed exception.

    The poll endpoint returns ``{"status": "failed", "error": "...", "error_code": "...", ...}``
    when a GPU operation fails. This function translates known error codes into typed exceptions.

    :param envelope: Decoded poll response for the failed operation.
    :keyword session_id: Session identifier attached to ``TrainingEngineError``.
    :return: A typed exception instance (not raised), or ``None`` if the
        failure doesn't match any known pattern.
    """
    error_code = envelope.get("error_code") or envelope.get("code")
    error_msg = envelope.get("error") or "Operation failed"
    debug_ref = envelope.get("debug_ref")

    if error_code == "engine_oom":
        # Formatted rather than concatenated so a non-string "error" payload
        # cannot raise TypeError while the failure is being reported.
        return BatchTooLargeError(
            f"{error_msg} (Try a smaller batch or shorter sequences.)",
            response=None,
        )

    if error_code in ("worker_crashed", "engine_timeout", "model_not_found"):
        return TrainingEngineError(
            error_msg,
            session_id=session_id,
            error_code=error_code,
            debug_ref=debug_ref,
        )

    if error_code == "invalid_request":
        return RequestValidationError(
            error_msg,
            error_code=error_code,
            debug_ref=debug_ref,
        )

    return None


# Backward-compat aliases.
EngineDeadError = TrainingEngineError
MalformedDatumError = RequestValidationError

__all__ = [
    "FineTuningSessionsError",
    "BatchTooLargeError",
    "NoCapacityError",
    "TrainingEngineError",
    "EngineDeadError",
    "ContentionError",
    "RequestValidationError",
    "MalformedDatumError",
]
+ +Scope: enriches the SDK's intentional, hand-written telemetry (HTTP +traces, session lifecycle, heartbeat, crash warnings) on these loggers: + +* ``azure.ai.finetuning_sessions._patch`` +* ``azure.ai.finetuning_sessions.aio._patch`` + +Autorest / code-generator helpers under ``_utils/`` are deliberately +excluded -- their logs are plumbing noise that operators do not correlate. + +Mechanism: a :class:`logging.Filter` attached to each emitting logger. +For each record, the filter walks the logger's parent chain (honoring +``propagate=False``) and asks "is any handler reachable?". If **no** +handler is reachable, the record will fall through to +``logging.lastResort`` -- whose hardcoded format string is +``"%(levelname)s:%(name)s:%(message)s"`` -- and thus show no timestamp. +In that case the filter prepends ``[] `` to ``record.msg`` +so the rendered line includes a timestamp. If a handler **is** reachable, +the filter is a no-op: the caller's formatter is responsible for the +timestamp (via ``%(asctime)s``, ``record.created``, or JSON time-field +handling), and the SDK does not duplicate it. + +The timestamp uses ``record.created`` (set automatically by Python at log +call time), not the time the filter runs -- preserving event time under +queued / async handlers. + +This module installs **no handler**, does not change propagation, and +never causes duplicate log lines in any caller-configured setup. + +Operators can opt out without code changes by setting the env var +``AZURE_AI_FINETUNING_SESSIONS_SDK_LOG_CONTEXT`` to any of ``0``, +``false``, ``no``, ``off``, ``disable``, ``disabled`` (case-insensitive). +Programmatic control is also available via +``install_default_logging(enabled=True)`` or ``enabled=False``. 
+""" +from __future__ import annotations + +import logging as _logging +import os as _os +from datetime import datetime as _datetime +from datetime import timezone as _timezone +from typing import Optional, Tuple + +_SDK_ROOT = "azure.ai.finetuning_sessions" + +# Environment variable that lets operators opt OUT of the timestamp +# filter without touching code. Default is enabled. Recognized "falsey" +# values (case-insensitive): "0", "false", "no", "off", "disable", +# "disabled". Any other value -- including unset -- leaves the filter +# enabled. +_ENV_VAR = "AZURE_AI_FINETUNING_SESSIONS_SDK_LOG_CONTEXT" +_FALSEY = frozenset({"0", "false", "no", "off", "disable", "disabled"}) + +# Child loggers that emit the SDK's intentional, user-facing telemetry. +# When a NEW hand-written SDK module starts emitting telemetry that +# should benefit from the no-handler timestamp fallback, add its dotted +# name here. +_SDK_EMITTING_LOGGERS: Tuple[str, ...] = ( + f"{_SDK_ROOT}._patch", + f"{_SDK_ROOT}.aio._patch", +) + + +def _enabled_from_env() -> bool: + raw = _os.environ.get(_ENV_VAR) + if raw is None: + return True + return raw.strip().lower() not in _FALSEY + + +def _has_any_handler(name: str) -> bool: + """True iff at least one handler is reachable in the logger's chain. + + Walks parents until a handler is found or propagation breaks. Mirrors + the lookup ``Logger.callHandlers`` does, so a ``False`` return means + the record will be dispatched to ``logging.lastResort``. + """ + c: Optional[_logging.Logger] = _logging.getLogger(name) + while c is not None: + if c.handlers: + return True + if not c.propagate: + return False + c = c.parent + return False + + +class _SdkTimestampFilter(_logging.Filter): + """Prepend an ISO-8601 UTC timestamp to ``record.msg`` only when no + handler is configured to render the record's time. 
Always returns + ``True`` -- enriches, never drops.""" + + def filter(self, record: _logging.LogRecord) -> bool: + if not _has_any_handler(record.name): + ts = _datetime.fromtimestamp( + record.created, _timezone.utc + ).isoformat(timespec="milliseconds") + record.msg = f"[{ts}] {record.msg}" + return True + + +def install_default_logging(enabled: Optional[bool] = None) -> None: + """Attach the timestamp filter to every SDK emitting logger. + + :keyword enabled: If ``None`` (default), enablement is read from the + ``AZURE_AI_FINETUNING_SESSIONS_SDK_LOG_CONTEXT`` env var; the + filter is on unless the var is set to a falsey value (``0``, + ``false``, ``no``, ``off``, ``disable``, ``disabled``). Pass + ``True`` / ``False`` to force-enable or force-disable + programmatically. + + Idempotent: a second call with ``enabled=True`` does not duplicate + the filter; a call with ``enabled=False`` removes any previously + installed instance. Installs no handler and does not change logger + propagation. + """ + if enabled is None: + enabled = _enabled_from_env() + + for name in _SDK_EMITTING_LOGGERS: + logger = _logging.getLogger(name) + existing = [f for f in logger.filters if isinstance(f, _SdkTimestampFilter)] + if not enabled: + for f in existing: + logger.removeFilter(f) + continue + if existing: + continue + logger.addFilter(_SdkTimestampFilter()) + + +__all__ = ["install_default_logging"] diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_patch.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_patch.py new file mode 100644 index 000000000000..e3d067106e10 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_patch.py @@ -0,0 +1,1322 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------- +"""Handwritten convenience layer on top of the generated FineTuningSessionClient. + +``FineTuningSession`` wraps a live session and exposes the hero-code API from +SPEC_FOUNDRY_AICLIENT.md: + + session = FineTuningSession(client, session_id="session_xxx") + fb_result = session.forward_backward(batch, loss_fn="cross_entropy") + opt_result = session.optim_step(AdamParams(learning_rate=1e-4)) + ckpt_result = session.save_weights("my_checkpoint") + sampler_result = session.save_weights_for_sampler(seq_id=0) + sample_result = session.sample(prompt_tokens, sampling_params, num_samples=4) + session.close() + +Each mutating method follows loom's two-step protocol: + 1. POST to the action endpoint — loom returns **200** with + ``{request_id, session_id, status: "pending"}``. + 2. GET ``/fine_tuning/sessions/{sessionId}/request/{requestId}`` — the server + long-polls (up to 5 minutes) and returns the typed result when the GPU finishes. + +Note: the generated ``begin_*`` methods on sub-clients use the Azure LRO (202 + +Operation-Location header) pattern and will **not** work against loom, which returns +200. Always use ``FineTuningSession`` methods for training operations. 
+""" +from __future__ import annotations + +import concurrent.futures as _futures +import json as _json +import logging as _logging +import os as _os +import threading as _threading +import time as _time +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union + +from azure.core import PipelineClient +from azure.core.exceptions import HttpResponseError as _HttpResponseError +from azure.core.exceptions import ServiceRequestError as _ServiceRequestError +from azure.core.exceptions import ServiceResponseError as _ServiceResponseError +from azure.core.pipeline import policies +from azure.core.rest import HttpRequest as _HttpRequest + +from ._exceptions import ( + _classify_http_error, + _classify_poll_failure, +) + +from .models import ( + AdamParams, + CreateSessionRequest, + Datum, + ForwardBackwardInput, + ForwardBackwardOperationResult, + ForwardBackwardRequest, + FromCheckpoint, + LoRAConfig, + LossFn, + LossFnConfig, + ModelInput, + ModelInputChunk, + OperationResult, + OperationType, + OptimStepRequest, + SampleRequest, + SamplingParams, + SaveCheckpointRequest, + SaveSamplerWeightsRequest, + TensorData, + FoundryFeaturesOptInKeys, +) +from ._client import FineTuningSessionClient as FineTuningSessionClientGenerated +from ._utils.model_base import SdkJSONEncoder as _SdkJSONEncoder, _deserialize as _deserialize_model + +# ── Loom wire-format → OperationResult discriminator map ───────────────────── +# Maps the last path segment of a Loom action URL to the SDK's "type" value. +# NOTE: "forward" maps to "forward_backward" because there is no dedicated +# ForwardOperationResult class yet. This means ForwardBackwardOperationResult +# must tolerate missing fields (total_loss, metrics) — do NOT add required +# fields to that class without also providing a separate ForwardOperationResult. 
_LOOM_SUBPATH_TO_OP_TYPE: dict[str, str] = {
    "forward_backward": "forward_backward",
    "forward": "forward_backward",
    "optim_step": "optim_step",
    "checkpoint": "save_checkpoint",
    "checkpoint_sample": "save_sampler_weights",
    "sample": "sample",
}

# ---------------------------------------------------------------------------
# Chunked forward_backward helpers
# ---------------------------------------------------------------------------

#: Maximum number of datums in a single forward_backward HTTP request.
_MAX_CHUNK_LEN = 1024

#: Approximate maximum payload size (bytes) for a single request.
_MAX_CHUNK_BYTES = 5_000_000


def _estimate_bytes_count(datum: Datum) -> int:
    """Estimate the serialised size of a single Datum.

    The estimate is a heuristic: every element of every token / tensor list is
    counted as 10 bytes.

    :param datum: The datum whose ``model_input.chunks`` and ``loss_fn_inputs``
        are measured.
    :return: Estimated size in bytes.
    """
    size = 0
    # Model input chunks — each token ID ≈ 10 bytes when JSON-serialised.
    for chunk in datum.model_input.chunks:
        size += len(chunk.tokens) * 10
    # Loss function inputs — each TensorData field's data list × 10.
    lfi = datum.loss_fn_inputs
    for field_name in ("target_tokens", "weights", "advantages", "logprobs"):
        td = getattr(lfi, field_name, None)
        if td is not None and hasattr(td, "data") and td.data is not None:
            size += len(td.data) * 10
    return size


def _chunk_data(data: List[Datum]) -> List[List[Datum]]:
    """Split Datum list into chunks respecting size limits.

    A new chunk is started when the current one already holds
    ``_MAX_CHUNK_LEN`` datums, or when adding the next datum would push a
    non-empty chunk past ``_MAX_CHUNK_BYTES``. A single datum larger than the
    byte limit still gets a chunk of its own.

    :param data: Datums in submission order.
    :return: Chunks in order; empty when ``data`` is empty.
    """
    chunks: List[List[Datum]] = []
    current: List[Datum] = []
    current_bytes = 0
    for datum in data:
        est = _estimate_bytes_count(datum)
        too_many_bytes = len(current) > 0 and current_bytes + est > _MAX_CHUNK_BYTES
        if too_many_bytes or len(current) == _MAX_CHUNK_LEN:
            chunks.append(current)
            current = []
            current_bytes = 0
        current.append(datum)
        current_bytes += est
    if current:
        chunks.append(current)
    return chunks


# ── Metric reduction ──────────────────────────────────────────────────────


def _reduce_mean(xs: List[float], weights: Optional[List[int]] = None) -> float:
    """Return the mean of ``xs``, weighted by ``weights`` when they are usable.

    Falls back to the unweighted mean when ``weights`` is ``None`` or sums to
    zero; returns ``0.0`` for an empty ``xs`` in that fallback.
    """
    if weights is None or sum(weights) == 0:
        return sum(xs) / len(xs) if xs else 0.0
    total = sum(x * w for x, w in zip(xs, weights))
    return total / sum(weights)


def _reduce_sum(xs: List[float]) -> float:
    """Return the sum of ``xs``."""
    return sum(xs)


def _reduce_min(xs: List[float]) -> float:
    """Return the smallest value in ``xs``."""
    return min(xs)


def _reduce_max(xs: List[float]) -> float:
    """Return the largest value in ``xs``."""
    return max(xs)


def _reduce_slack(xs: List[float], weights: Optional[List[int]] = None) -> float:
    """Return ``max(xs)`` minus the (optionally weighted) mean of ``xs``."""
    return max(xs) - _reduce_mean(xs, weights)


def _order_insensitive_hash(xs: list) -> int:
    """Order-insensitive hash for metric deduplication."""
    if xs and isinstance(xs[0], set):
        return hash(tuple(sorted([y for x in xs for y in x])))
    return hash(tuple(sorted(int(x) for x in xs)))


_REDUCE_MAP = {
    "mean": _reduce_mean,
    "sum": _reduce_sum,
    "min": _reduce_min,
    "max": _reduce_max,
    "slack": _reduce_slack,
    "hash_unordered": _order_insensitive_hash,
    "unique": lambda xs: xs,
}


def _metrics_reduction(
    results: List[ForwardBackwardOperationResult],
    chunk_sizes: List[int],
) -> dict:
    """Reduce metrics across chunked forward_backward results.

    Uses ``chunk_sizes`` (number of datums per chunk) as weights.

    Metric keys have the form ``"<name>:<reduction>"``. The key set is taken
    from the first result; keys with an unknown reduction, a malformed name, or
    that are missing from any result are skipped.

    :param results: Per-chunk results, in chunk order.
    :param chunk_sizes: Number of datums in each chunk, parallel to ``results``.
    :return: Mapping of metric key to reduced value; empty when ``results`` is empty.
    """
    if not results:
        return {}
    # `forward` route returns a base ``OperationResult`` with no `metrics`
    # field; tolerate that by using getattr throughout.
    # TODO(forward-route): add a proper ``ForwardOperationResult`` subclass
    # to ``models/_models.py`` (per_datum_logprobs only, no total_loss /
    # metrics) and drop the ``"forward": "forward_backward"`` mapping in
    # ``_LOOM_SUBPATH_TO_OP_TYPE``. Then this defensive getattr can go away.
    first_metrics = getattr(results[0], "metrics", None) or {}
    keys = first_metrics.keys()
    res: dict = {}
    for key in keys:
        parts = key.split(":")
        if len(parts) != 2:
            continue
        name, reduction = parts
        if reduction not in _REDUCE_MAP:
            _logger.debug(
                "Invalid reduction=%s for metric name=%s. Expecting one of %s",
                reduction, name, list(_REDUCE_MAP.keys()),
            )
            continue
        if not all(key in (getattr(m, "metrics", None) or {}) for m in results):
            continue
        values = [(getattr(m, "metrics", None) or {})[key] for m in results]
        reduce_fn = _REDUCE_MAP[reduction]

        if reduction in ("mean", "slack"):
            res[key] = reduce_fn(values, chunk_sizes)
        elif reduction == "unique":
            # Keep the first value under the original key and fan the rest
            # out under numbered suffixes (key_1, key_2, ...).
            res[key] = values[0]
            res.update({f"{key}_{i + 1}": v for i, v in enumerate(values[1:])})
        else:
            res[key] = reduce_fn(values)
    return res


def _combine_fwd_bwd_results(
    results: List[ForwardBackwardOperationResult],
    chunk_sizes: List[int],
) -> ForwardBackwardOperationResult:
    """Combine results from multiple forward_backward chunks.

    ``total_loss`` is summed, ``per_datum_logprobs`` and ``loss_fn_outputs`` are
    concatenated in chunk order, and metrics are reduced by
    :func:`_metrics_reduction`.

    Field access is defensive (``getattr`` with a default), matching
    ``_metrics_reduction``: a result that lacks ``total_loss`` or
    ``per_datum_logprobs`` contributes ``0.0`` / nothing instead of raising.

    :param results: Per-chunk results, in chunk order.
    :param chunk_sizes: Number of datums in each chunk, parallel to ``results``.
    :return: One combined result; ``total_loss=0.0`` when ``results`` is empty.
    """
    if not results:
        return ForwardBackwardOperationResult(total_loss=0.0)

    combined_metrics = _metrics_reduction(results, chunk_sizes)
    combined_logprobs: List[TensorData] = []
    for r in results:
        logprobs = getattr(r, "per_datum_logprobs", None)
        if logprobs:
            combined_logprobs.extend(logprobs)
    # Combine loss_fn_outputs (extra JSON field carrying per-datum logprobs
    # from the Loom server). The cookbook reads this field first, falling
    # back to per_datum_logprobs only when it is absent.
    combined_lfo: list = []
    for r in results:
        lfo = r.get("loss_fn_outputs") if hasattr(r, "get") else None
        if lfo:
            combined_lfo.extend(lfo)
    # A missing / None total_loss counts as 0.0 so one loss-less chunk cannot
    # abort the whole combined step with a TypeError.
    total_loss = sum((getattr(r, "total_loss", None) or 0.0) for r in results)
    combined: dict = {
        "total_loss": total_loss,
        "per_datum_logprobs": combined_logprobs or None,
        "metrics": combined_metrics or None,
    }
    if combined_lfo:
        combined["loss_fn_outputs"] = combined_lfo
    return ForwardBackwardOperationResult(combined)


def _normalize_loom_result(data: dict, op_type: str, request_id: str) -> dict:
    """Normalize the Loom poll-endpoint wire format into an OperationResult dict.

    The Loom server returns raw engine results (no ``"type"`` discriminator, metrics
    under namespaced keys like ``"total_loss:sum"``). This function injects the
    discriminator and promotes metric fields so ``_deserialize(OperationResult, ...)``
    returns the correct typed subclass.

    :param data: Raw result dict from the poll endpoint; it is copied, not mutated.
    :param op_type: SDK operation type (a value of ``_LOOM_SUBPATH_TO_OP_TYPE``).
    :param request_id: Request ID, used as ``operation_id`` when the server sent none.
    :return: A new dict with discriminator and promoted fields filled in.
    """
    out = dict(data)
    out.setdefault("type", op_type)
    out.setdefault("operation_id", request_id)
    out.setdefault("status", "succeeded")

    metrics: dict = out.get("metrics") or {}

    # ``or 0`` below also covers a metric that is present but null, which
    # would otherwise make float()/int() raise TypeError.
    if op_type == "forward_backward":
        if "total_loss" not in out:
            out["total_loss"] = float(metrics.get("total_loss:sum") or 0.0)

    elif op_type == "optim_step":
        if "grad_norm" not in out:
            out["grad_norm"] = float(metrics.get("skyrl.ai/grad_norm") or 0.0)
        if "step_count" not in out:
            out["step_count"] = int(metrics.get("step_count") or 0)

    elif op_type == "save_sampler_weights":
        # Server may return "type": "save_weights_for_sampler" — normalise to SDK value.
        out["type"] = "save_sampler_weights"
        out.setdefault("checkpoint_id", "")
        out.setdefault("sampling_session_id", "")

    elif op_type == "save_checkpoint":
        out["type"] = "save_checkpoint"  # force, in case server returns a different value
        out.setdefault("checkpoint_id", "")
        out.setdefault("path", "")

    return out


_PREVIEW = FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW
_API_VERSION = "v1"
_logger = _logging.getLogger(__name__)

#: Per-call HTTP timeout for the (now short-poll) retrieve-status endpoint.
#: The server returns immediately, so this only needs to cover network RTT and
#: a one-shot DB read.
_RETRIEVE_TIMEOUT = 15.0

#: Adaptive poll backoff bounds (seconds) for the retrieve-status endpoint.
#: We start at MIN (catches fast operations cheaply) and double up to MAX
#: (bounds RPS for long-running operations). Backoff resets on every new
#: poll loop (i.e. per-request), so each future starts polling at MIN.
_RETRIEVE_POLL_MIN = 1.0
_RETRIEVE_POLL_MAX = 30.0


def _operation_timeout_from_env() -> Optional[float]:
    """Read the per-operation timeout from the environment.

    Reads ``AZURE_AI_FINETUNING_SESSIONS_OPERATION_TIMEOUT_SEC``:

    * unset, unparsable, or NaN -> default of 3600 seconds;
    * a positive number -> that many seconds;
    * zero or negative -> ``None`` (timeout disabled).

    NaN is treated as invalid rather than as "disabled": it parses as a float
    but fails every comparison, so it would otherwise silently turn the
    timeout off and allow unbounded retries.

    :return: Timeout in seconds, or ``None`` when explicitly disabled.
    """
    import math as _math  # local import: only needed for the NaN check

    env_name = "AZURE_AI_FINETUNING_SESSIONS_OPERATION_TIMEOUT_SEC"
    default_sec = 3600.0
    raw = _os.environ.get(env_name)
    if raw is None:
        _logger.debug("%s not set; using default 3600s operation timeout", env_name)
        return default_sec  # not set -> default
    try:
        value = float(raw)
    except ValueError:
        value = float("nan")  # funnel into the shared "invalid" path below
    if _math.isnan(value):
        _logger.warning("Invalid %s=%r; using default 3600s", env_name, raw)
        return default_sec  # invalid -> default
    if value > 0:
        _logger.debug("%s=%s; operation timeout set to %.0fs", env_name, raw, value)
        return value
    _logger.warning("%s=%s; operation timeout DISABLED (retries can run unbounded)", env_name, raw)
    return None  # 0 (or negative) -> disabled


# Per-operation polling timeout for training/sample/checkpoint requests. This
# bounds retry loops on repeated 5xx/network failures while still allowing long
# operations; set env var to 0 (or negative) to disable.
_DEFAULT_OPERATION_TIMEOUT_SEC = _operation_timeout_from_env()


class _ErrorBudget:
    """Bounds a sustained error streak; healthy progress is unbounded.

    ``budget_sec=None`` disables the budget. ``on_exhausted(reason, budget_sec)``
    builds the exception to raise, where ``reason`` is the short error label
    (e.g. ``"HTTP 503"``) and ``budget_sec`` is the configured budget in seconds.
    """

    def __init__(
        self,
        budget_sec: Optional[float],
        *,
        on_exhausted: Callable[[str, float], BaseException],
    ) -> None:
        self._budget = budget_sec
        self._on_exhausted = on_exhausted
        # Monotonic time at which the current error streak exhausts the
        # budget; None while no streak is in progress.
        self._deadline: Optional[float] = None

    @classmethod
    def for_polling(
        cls,
        budget_sec: Optional[float],
        *,
        op_type: str,
        request_id: str,
    ) -> "_ErrorBudget":
        """Build a poll-loop budget that raises ``TimeoutError`` when exhausted.

        Shared by the sync and async poll loops so the exhaustion message lives
        in one place.
        """
        return cls(
            budget_sec,
            on_exhausted=lambda reason, budget: TimeoutError(
                f"Timed out after {budget:.0f}s of sustained "
                f"errors ({reason}) waiting for {op_type or 'operation'} request {request_id}"
            ),
        )

    def clear(self) -> None:
        """Disarm; call on every healthy poll."""
        self._deadline = None

    def consume(self, reason: str) -> None:
        """Arm on the first error, raise once the streak outlasts the budget."""
        if self._budget is None:
            return
        now = _time.monotonic()
        if self._deadline is None:
            self._deadline = now + self._budget
        elif now > self._deadline:
            raise self._on_exhausted(reason, self._budget)


# Poll-progress state shared by sync and async clients.
_poll_log_last: dict[tuple[str, str], float] = {}
_request_active_since: dict[tuple[str, str], float] = {}
_POLL_LOG_DEDUP_SEC = 30.0


def _maybe_log_poll_progress(
    envelope: dict,
    session_id: str,
    request_id: str,
    op_type: str,
    elapsed: float,
) -> None:
    """Emit a throttled poll-progress log for a pending request.

    Nothing is logged until the request has been waiting at least
    ``_POLL_LOG_DEDUP_SEC``; after that, at most one line per
    ``(session_id, op_type)`` is emitted every ``_POLL_LOG_DEDUP_SEC``.

    :param envelope: Poll response; ``phase == "resuming_session"`` marks the
        request as queued rather than running.
    :param session_id: Session the request belongs to.
    :param request_id: Request being polled.
    :param op_type: Operation label used in the log line and the throttle key.
    :param elapsed: Seconds since the caller started polling; reported while
        the request is queued.
    """
    now = _time.monotonic()
    is_queued = envelope.get("phase") == "resuming_session"
    req_key = (session_id, request_id)

    if is_queued:
        # Restart active timing if a request moves back to the queue.
        _request_active_since.pop(req_key, None)
        display_elapsed = elapsed
    else:
        active_since = _request_active_since.get(req_key)
        if active_since is None:
            _request_active_since[req_key] = now
            display_elapsed = 0.0
        else:
            display_elapsed = now - active_since

    if display_elapsed < _POLL_LOG_DEDUP_SEC:
        return
    dedup_key = (session_id, op_type)
    # Use None (not 0.0) as the "never logged" sentinel: the reference point
    # of time.monotonic() is undefined, so comparing against 0.0 is unsafe.
    last_logged = _poll_log_last.get(dedup_key)
    if last_logged is not None and now - last_logged < _POLL_LOG_DEDUP_SEC:
        return
    _poll_log_last[dedup_key] = now
    if is_queued:
        _logger.info(
            "[poller] %s/%s (op=%s) queued waiting for capacity \u2014 %.0fs elapsed",
            session_id, request_id, op_type, display_elapsed,
        )
    else:
        _logger.info(
            "[poller] %s/%s (op=%s) in progress \u2014 %.0fs elapsed",
            session_id, request_id, op_type, display_elapsed,
        )


def _clear_poll_log_state(session_id: str, request_id: str, op_type: str) -> None:
    """Drop poll-progress state for a finished request."""
    _request_active_since.pop((session_id, request_id), None)
    _poll_log_last.pop((session_id, op_type), None)


# ---------------------------------------------------------------------------
# Verbose HTTP logging toggle
# ---------------------------------------------------------------------------
# Set to True (or set env var FINETUNING_VERBOSE_HTTP=1) to log every request
# URL + body and response status + body for all SDK API calls.
+VERBOSE_HTTP: bool = _os.environ.get("FINETUNING_VERBOSE_HTTP", "").lower() in ("1", "true", "yes") + + +def _log_http(direction: str, method: str, url: str, status: Optional[int] = None, body: Any = None) -> None: + """Log an HTTP request or response if VERBOSE_HTTP is enabled.""" + if not VERBOSE_HTTP: + return + body_str = "" + if body is not None: + try: + body_str = _json.dumps(body, indent=2) if isinstance(body, (dict, list)) else str(body) + except Exception: # pragma: no cover + body_str = repr(body) + body_str = f"\n{body_str}" + if direction == "request": + _logger.info("[HTTP] --> %s %s%s", method, url, body_str) + else: + _logger.info("[HTTP] <-- %s %s status=%d%s", method, url, status or 0, body_str) + + +def _base_headers(extra: Optional[dict] = None) -> dict: + """Build the common headers for every Loom request. + + Reads ``X_COGNITIVE_SUBSCRIPTION_ID`` (or ``COGNITIVE_SUBSCRIPTION_ID`` / + ``AZURE_SUBSCRIPTION_ID`` as fallbacks) from the environment and injects it + as ``apim-subscription-id``. The remote Loom endpoint (LOOM_SETUP_MODE=prod) + requires this header on every request — same as + ``clean_remote.sh`` which sets ``X_COGNITIVE_SUBSCRIPTION_ID=local-sub``. + """ + headers: dict = { + "Accept": "application/json", + "Foundry-Features": _PREVIEW.value, + } + sub_id = ( + _os.environ.get("X_COGNITIVE_SUBSCRIPTION_ID") + or _os.environ.get("COGNITIVE_SUBSCRIPTION_ID") + or _os.environ.get("AZURE_SUBSCRIPTION_ID") + ) + if sub_id: + headers["apim-subscription-id"] = sub_id + if extra: + headers.update(extra) + return headers + + +class FineTuningSessionClient(FineTuningSessionClientGenerated): # pylint: disable=client-accepts-api-version-keyword + """FineTuningSessionClient. 
+ + :ivar sessions: SessionsOperations operations + :vartype sessions: azure.ai.finetuning_sessions.operations.SessionsOperations + :ivar training: TrainingOperations operations + :vartype training: azure.ai.finetuning_sessions.operations.TrainingOperations + :ivar checkpoints: CheckpointsOperations operations + :vartype checkpoints: azure.ai.finetuning_sessions.operations.CheckpointsOperations + :ivar sampling: SamplingOperations operations + :vartype sampling: azure.ai.finetuning_sessions.operations.SamplingOperations + :ivar operations: Operations operations + :vartype operations: azure.ai.finetuning_sessions.operations.Operations + :param endpoint: Foundry Project endpoint in the form + "https://{ai-services-account-name}.services.ai.azure.com/api/projects/{project-name}". If you + only have one Project in your Foundry Hub, or to target the default Project in your Hub, use + the form "https://{ai-services-account-name}.services.ai.azure.com/api/projects/_project". + Required. + :type endpoint: str + :param credential: Credential used to authenticate requests to the service. Required. + :type credential: ~azure.core.credentials.TokenCredential + :keyword int polling_interval: Default waiting time between two polls for LRO operations if no + Retry-After header is present. 
+ """ + def __init__(self, endpoint: str, credential: "TokenCredential", *, allow_insecure_http: bool = False, + **kwargs: Any) -> None: + provided_policies = kwargs.get("policies") + original_kwargs = dict(kwargs) + super().__init__(endpoint=endpoint, credential=credential, allow_insecure_http=allow_insecure_http, **original_kwargs) + + _policies = provided_policies + if _policies is None: + _policies = [ + policies.RequestIdPolicy(**original_kwargs), + self._config.headers_policy, + self._config.user_agent_policy, + self._config.proxy_policy, + policies.ContentDecodePolicy(**original_kwargs), + self._config.redirect_policy, + self._config.retry_policy, + self._config.authentication_policy, + self._config.custom_hook_policy, + self._config.logging_policy, + policies.DistributedTracingPolicy(**original_kwargs), + policies.SensitiveHeaderCleanupPolicy(**original_kwargs) if self._config.redirect_policy else None, + self._config.http_logging_policy, + ] + + self._session_client = PipelineClient(base_url=endpoint, policies=_policies, **original_kwargs) + self.sessions._client = self._session_client + + +class FineTuningSession: + """Convenience wrapper around a single fine-tuning session. + + Mirrors the hero-code surface from SPEC_FOUNDRY_AICLIENT.md so callers + can write training loops without constructing raw request bodies. + + :param client: The generated ``FineTuningSessionClient``. + :param session_id: The session ID returned by the server after creating a session. + """ + + def __init__(self, client: "FineTuningSessionClient", session_id: str) -> None: + self._client = client + self.session_id = session_id + # Derive the heartbeat session_id: heartbeat endpoint looks up by + # "session_xxx" in the sessions table, not "model_xxx". 
+ raw_id = session_id.removeprefix("model_") + self._heartbeat_session_id = f"session_{raw_id}" + self._heartbeat_stop = _threading.Event() + self._heartbeat_thread: Optional[_threading.Thread] = None + self._start_heartbeat() + + # ── Background heartbeat ────────────────────────────────────────────────── + + def _start_heartbeat(self, interval_sec: float = 30.0) -> None: + """Start a daemon thread that sends heartbeat every interval_sec.""" + def _heartbeat_loop() -> None: + while not self._heartbeat_stop.wait(interval_sec): + try: + self.heartbeat() + except Exception as exc: + _logger.warning("[heartbeat] failed for %s: %s", self._heartbeat_session_id, exc) + + self._heartbeat_thread = _threading.Thread( + target=_heartbeat_loop, name="fts-heartbeat", daemon=True + ) + self._heartbeat_thread.start() + _logger.info("[heartbeat] started (interval=%.0fs, session=%s)", interval_sec, self._heartbeat_session_id) + + def _stop_heartbeat(self) -> None: + """Stop the background heartbeat thread.""" + self._heartbeat_stop.set() + if self._heartbeat_thread is not None: + self._heartbeat_thread.join(timeout=5.0) + self._heartbeat_thread = None + + # ── Factory ─────────────────────────────────────────────────────────────── + + @classmethod + def create( + cls, + client: "FineTuningSessionClient", + *, + base_model: str, + lora_config: Optional[LoRAConfig] = None, + type: str = "training", + from_checkpoint: Optional[FromCheckpoint] = None, + timeout_sec: float = 600.0, + ) -> "FineTuningSession": + """Create a fine-tuning session and wait until the model is loaded. + + Combines ``POST /fine_tuning/sessions`` (which triggers an async model-load + on the server) with polling of the returned ``request_id`` until the load + completes, then returns a ready-to-use :class:`FineTuningSession`. + + :param client: The :class:`~azure.ai.finetuning_sessions.FineTuningSessionClient`. + :param base_model: Name of the base model to load (e.g. ``"Llama-3.1-8B"``). 
+ :param lora_config: Optional LoRA adapter config. Server default is used if omitted. + :param type: Session type string. Defaults to ``"training"``. + :param from_checkpoint: Optional :class:`FromCheckpoint` specifying the + source session and checkpoint to bootstrap from (continual fine-tuning + / resume from checkpoint). + :param timeout_sec: Maximum seconds to wait for the model to load. Defaults to ``600.0``. + :note: Poll cadence is controlled by an internal adaptive backoff + (``_RETRIEVE_POLL_MIN`` doubling up to ``_RETRIEVE_POLL_MAX``); + it is not currently caller-configurable. + :raises RuntimeError: If the server reports status ``"failed"`` or the timeout expires. + :return: A :class:`FineTuningSession` instance ready for training operations. + """ + create_request = CreateSessionRequest( + type=type, + base_model=base_model, + lora_config=lora_config, + ) + body = _json.loads(_json.dumps( + create_request, + cls=_SdkJSONEncoder, + exclude_readonly=True, + )) + if from_checkpoint is not None: + body["from_checkpoint"] = _json.loads( + _json.dumps(from_checkpoint, cls=_SdkJSONEncoder, exclude_readonly=True) + ) + body_json = _json.dumps(body) + post_req = _HttpRequest( + "POST", + "{endpoint}/fine_tuning/sessions", + headers=_base_headers({"Content-Type": "application/json"}), + params={"api-version": _API_VERSION}, + content=body_json, + ) + _log_http("request", "POST", "/fine_tuning/sessions", body=_json.loads(body_json)) + post_resp = client.send_request(post_req) + _log_http("response", "POST", "/fine_tuning/sessions", status=post_resp.status_code, body=post_resp.json()) + if post_resp.status_code >= 400: + try: + resp_body = post_resp.json() + except Exception: + resp_body = None + typed = _classify_http_error( + post_resp.status_code, resp_body, response=post_resp + ) + if typed is not None: + raise typed + post_resp.raise_for_status() + data = post_resp.json() + raw_session_id: str = data["session_id"] + request_id: str = data["request_id"] + 
_logger.info("[create] POST /fine_tuning/sessions response: raw_session_id=%s, request_id=%s, full_response=%s", raw_session_id, request_id, data) + + # The loom server stores the model record as f"model_{session_id}" (see + # loom_create_model in the SQL/Cosmos providers). Every route handler that + # performs a training operation calls loom_require_model(provider, ) + # which does loom_get_model(). Using the "model_" prefixed form as + # our session_id means the URL path parameter matches the stored model_id, so + # the lookup succeeds — no changes required on the server or engine side. + session_id: str = f"model_{raw_session_id}" + _logger.info("[create] session_id transformed: raw=%s -> resource_id=%s (used in all subsequent URL paths)", raw_session_id, session_id) + + # Wait for the model-load request to complete. + # The retrieve-status endpoint (GET /fine_tuning/sessions/{id}/request/{rid}) + # is now non-blocking: each call returns a {status, result, error} envelope + # immediately. We short-poll until status=="completed" (or "failed") + # using adaptive backoff (MIN doubling up to MAX) to keep poll RPS bounded + # for slow model loads while still reacting quickly when ready. + deadline = _time.monotonic() + timeout_sec + _create_conn_backoff = 1.0 + _create_poll_backoff = _RETRIEVE_POLL_MIN + _create_poll_start = _time.monotonic() + while True: + try: + poll_req = _HttpRequest( + "GET", + "{endpoint}" + f"/fine_tuning/sessions/{session_id}/request/{request_id}", + headers=_base_headers(), + params={"api-version": _API_VERSION}, + ) + poll_path = f"/fine_tuning/sessions/{session_id}/request/{request_id}" + _log_http("request", "GET", poll_path) + poll_resp = client.send_request(poll_req) + # Parse JSON once — `azure.core.rest` responses do not guarantee + # the body can be read more than once. 
+ envelope = poll_resp.json() if poll_resp.status_code == 200 else None + _log_http("response", "GET", poll_path, status=poll_resp.status_code, body=envelope) + + if poll_resp.status_code == 200: + env_status = envelope.get("status") + if env_status == "completed": + _logger.info("[create] model load completed: %s", envelope) + _clear_poll_log_state(session_id, request_id, "create_session") + break + if env_status == "failed": + _clear_poll_log_state(session_id, request_id, "create_session") + typed = _classify_poll_failure(envelope, session_id=session_id) + if typed is not None: + raise typed + raise RuntimeError( + f"Model load failed for session_id={raw_session_id} " + f"[{envelope.get('error_code') or envelope.get('code') or 'unknown'}]: " + f"{envelope.get('error') or 'unknown error'} " + f"(debug_ref={envelope.get('debug_ref') or 'n/a'})" + ) + # pending -> sleep with adaptive backoff and retry (subject to deadline). + if _time.monotonic() > deadline: + raise RuntimeError( + f"Timed out after {timeout_sec}s waiting for session_id={raw_session_id} to become ready" + ) + elapsed = _time.monotonic() - _create_poll_start + _maybe_log_poll_progress(envelope, session_id, request_id, "create_session", elapsed) + _create_conn_backoff = 1.0 # reset on successful HTTP exchange + _time.sleep(_create_poll_backoff) + _create_poll_backoff = min(_create_poll_backoff * 2, _RETRIEVE_POLL_MAX) + continue + + # Retry on 5xx and on transient client-side conditions: + # 408 Request Timeout -- intermittent network/proxy timeout + # 429 Too Many Requests -- server-side throttling + # Both are safe to retry with the same adaptive backoff used + # for pending polls. 
+ if ( + 500 <= poll_resp.status_code < 600 + or poll_resp.status_code in (408, 429) + ): + if _time.monotonic() > deadline: + raise RuntimeError( + f"Timed out after {timeout_sec}s waiting for session_id={raw_session_id} to become ready" + ) + elapsed = _time.monotonic() - _create_poll_start + _logger.debug( + "[poller] retry on %s/%s after HTTP %d (%.0fs elapsed)", + session_id, request_id, poll_resp.status_code, elapsed, + ) + _create_conn_backoff = 1.0 + # Honor Retry-After header if present. + retry_after = poll_resp.headers.get("Retry-After") + if retry_after is not None: + try: + poll_wait = float(retry_after) + except (ValueError, TypeError): + poll_wait = _RETRIEVE_POLL_MIN + else: + poll_wait = _RETRIEVE_POLL_MIN + _time.sleep(poll_wait) + continue + + # Any other error — fail immediately + try: + poll_body = poll_resp.json() + except Exception: + poll_body = None + typed = _classify_http_error( + poll_resp.status_code, poll_body, response=poll_resp, session_id=session_id + ) + if typed is not None: + raise typed + poll_resp.raise_for_status() + + except (_ServiceRequestError, _ServiceResponseError) as exc: + # Transient network error — exponential backoff then retry. 
+ if _time.monotonic() > deadline: + raise RuntimeError( + f"Timed out after {timeout_sec}s waiting for session_id={raw_session_id} to become ready" + ) from exc + elapsed = _time.monotonic() - _create_poll_start + _logger.warning( + "[poller] retry on %s/%s after %s(%s) (%.0fs elapsed), backoff %.1fs", + session_id, request_id, type(exc).__name__, exc, elapsed, _create_conn_backoff, + ) + _time.sleep(_create_conn_backoff) + _create_conn_backoff = min(_create_conn_backoff * 2, 30.0) + continue + + return cls(client, session_id=session_id) + + @classmethod + def create_from_checkpoint( + cls, + client: "FineTuningSessionClient", + *, + checkpoint_path: str, + base_model: str, + lora_config: Optional[LoRAConfig] = None, + type: str = "training", + timeout_sec: float = 600.0, + ) -> "FineTuningSession": + """Create a session resumed from a previously saved training checkpoint. + + This is a convenience wrapper around :meth:`create` that parses a + checkpoint path string and passes it as ``from_checkpoint``. + + The new session's LoRA weights, optimizer state, and scheduler step are + all bootstrapped from the checkpoint — equivalent to calling ``create`` + with ``from_checkpoint=FromCheckpoint(source_session_id=..., checkpoint_id=...)``. + + :param client: The :class:`~azure.ai.finetuning_sessions.FineTuningSessionClient`. + :param checkpoint_path: Reference to a saved training checkpoint. + Accepted formats: + - ``"/"`` + - ``"model_/"`` + :param base_model: Base model name. Must match the checkpoint's source. + :param lora_config: Optional LoRA config override. + :param type: Session type. Defaults to ``"training"``. + :param timeout_sec: Maximum seconds to wait for model load. + :raises ValueError: If ``checkpoint_path`` cannot be parsed. + :return: A ready-to-use :class:`FineTuningSession`. 
+ """ + parts = checkpoint_path.split("/") + if len(parts) != 2 or not parts[0] or not parts[1]: + raise ValueError( + "checkpoint_path must be '/' with exactly one '/' separator, " + f"got: {checkpoint_path!r}" + ) + source_session_id, checkpoint_id = parts + # Normalize: the API expects the model_ prefix on source_session_id + if not source_session_id.startswith("model_"): + source_session_id = f"model_{source_session_id}" + return cls.create( + client, + base_model=base_model, + lora_config=lora_config, + type=type, + from_checkpoint=FromCheckpoint( + source_session_id=source_session_id, + checkpoint_id=checkpoint_id, + ), + timeout_sec=timeout_sec, + ) + + # ── Low-level helper ────────────────────────────────────────────────────── + + def _post_and_poll(self, subpath: str, body_model: Any, extra_params: Optional[dict] = None, extra_result_fields: Optional[dict] = None) -> OperationResult: + """POST to a loom action endpoint (returns 200 + request_id), then + long-poll GET /request/{request_id} until the GPU finishes. + + Loom returns 200 (not 202) with ``{request_id, session_id, status}`` + from all mutating operations. The poll endpoint blocks server-side + (up to 5 minutes) and returns the typed result directly. + + Retries 408 / 5xx / transient network errors. The timeout is an ERROR + budget, not a wall-clock budget — see below. Set + AZURE_AI_FINETUNING_SESSIONS_OPERATION_TIMEOUT_SEC=0 to disable it. + + Timeout policy — IMPORTANT: + The budget (``_DEFAULT_OPERATION_TIMEOUT_SEC``, default 3600s) bounds + how long we tolerate a **sustained error streak**, NOT how long the + operation may take. It is per operation, not per job. + + * Healthy progress is NOT bounded. While the server keeps returning a + pending 200 — whether the request is queued waiting for GPU capacity + or actively in progress — we keep polling indefinitely. A healthy + poll CLEARS the error budget. + * Errors ARE bounded. 
The first 5xx / 408 / 429 / transient network + error after a healthy poll arms the error deadline + (``_DEFAULT_OPERATION_TIMEOUT_SEC`` from that moment). Further errors + do NOT extend it; the next healthy 200 disarms it. If errors persist + past the budget we raise ``TimeoutError`` so a real backend outage + fails fast instead of hanging. + + Net effect: a request can sit in the capacity queue for hours without + being killed, but a stuck/erroring backend is surfaced within the + budget. Note this means a server that returns healthy-pending forever + (never completes, never errors) will poll forever — guard against that + with server-side stall detection, not this client timeout. + """ + body_json = _json.dumps(body_model, cls=_SdkJSONEncoder, exclude_readonly=True) + post_params: dict = {"api-version": _API_VERSION} + if extra_params: + post_params.update(extra_params) + post_req = _HttpRequest( + "POST", + "{endpoint}" + subpath, + headers=_base_headers({"Content-Type": "application/json"}), + params=post_params, + content=body_json, + ) + _log_http("request", "POST", subpath, body=_json.loads(body_json)) + post_resp = self._client.send_request(post_req) + _log_http("response", "POST", subpath, status=post_resp.status_code, body=post_resp.json()) + if post_resp.status_code >= 400: + try: + resp_body = post_resp.json() + except Exception: + resp_body = None + typed = _classify_http_error( + post_resp.status_code, resp_body, response=post_resp, session_id=self.session_id + ) + if typed is not None: + raise typed + post_resp.raise_for_status() + data = post_resp.json() + request_id = data["request_id"] + session_id = data.get("session_id", self.session_id) + + # Long-poll the result directly so we can normalize the Loom wire format + # before deserializing. The generated operations.get() passes the raw JSON + # straight to _deserialize(OperationResult, ...) 
which expects a "type" + # discriminator field — but the Loom server returns its own engine format + # (no "type", metrics under namespaced keys). We do the GET ourselves, + # normalize, then deserialize. + op_type = _LOOM_SUBPATH_TO_OP_TYPE.get(subpath.rsplit("/", 1)[-1], "") + poll_req = _HttpRequest( + "GET", + "{endpoint}" + f"/fine_tuning/sessions/{session_id}/request/{request_id}", + headers=_base_headers(), + params={"api-version": _API_VERSION}, + ) + poll_path = f"/fine_tuning/sessions/{session_id}/request/{request_id}" + + # Short-poll the {status, result, error} envelope. The server returns + # immediately on every call; we sleep with adaptive backoff (MIN doubling + # up to MAX) between pending polls. + # + # Timeout policy: the budget is an ERROR budget, not a wall-clock budget. + # * A healthy pending 200 (queued waiting for capacity, or in progress) + # does NOT consume the budget — it CLEARS it. So a request can sit in + # the capacity queue indefinitely as long as the server keeps + # reporting healthy progress. + # * The first 5xx / 408 / 429 / transient network error after a healthy + # poll ARMS the error deadline (`_DEFAULT_OPERATION_TIMEOUT_SEC` from + # now). Subsequent errors do NOT extend it; the next healthy 200 + # disarms it. If the error streak outlasts the budget we raise + # TimeoutError (fail fast on a real outage). + # Set the env var <= 0 to disable the error budget (retry forever). 
+ error_budget = _ErrorBudget.for_polling( + _DEFAULT_OPERATION_TIMEOUT_SEC, op_type=op_type, request_id=request_id + ) + + connection_error_backoff = 1.0 + poll_backoff = _RETRIEVE_POLL_MIN + result_data: Any = None + poll_start = _time.monotonic() + while True: + try: + _log_http("request", "GET", poll_path) + poll_resp = self._client.send_request(poll_req) + + if poll_resp.status_code == 200: + envelope = poll_resp.json() + _log_http("response", "GET", poll_path, status=200, body=envelope) + env_status = envelope.get("status") + if env_status == "completed": + result_data = envelope.get("result") or {} + _clear_poll_log_state(session_id, request_id, op_type) + break + if env_status == "failed": + _clear_poll_log_state(session_id, request_id, op_type) + typed = _classify_poll_failure(envelope, session_id=session_id) + if typed is not None: + raise typed + raise RuntimeError( + f"Request failed " + f"[{envelope.get('error_code') or envelope.get('code') or 'unknown'}]: " + f"{envelope.get('error') or 'no error message'} " + f"(debug_ref={envelope.get('debug_ref') or 'n/a'})" + ) + # pending -> healthy progress: clear the error budget (queued + # / in-progress time is unbounded), then sleep with backoff. + elapsed = _time.monotonic() - poll_start + _maybe_log_poll_progress(envelope, session_id, request_id, op_type, elapsed) + error_budget.clear() + connection_error_backoff = 1.0 + _time.sleep(poll_backoff) + poll_backoff = min(poll_backoff * 2, _RETRIEVE_POLL_MAX) + continue + + # Retry on 5xx and on transient 408/429 (timeout / throttling). + if ( + 500 <= poll_resp.status_code < 600 + or poll_resp.status_code in (408, 429) + ): + elapsed = _time.monotonic() - poll_start + _logger.debug( + "[poller] retry on %s/%s after HTTP %d (%.0fs elapsed)", + session_id, request_id, poll_resp.status_code, elapsed, + ) + error_budget.consume(f"HTTP {poll_resp.status_code}") + connection_error_backoff = 1.0 + # Honor Retry-After header if present. 
+ retry_after = poll_resp.headers.get("Retry-After") + if retry_after is not None: + try: + poll_wait = float(retry_after) + except (ValueError, TypeError): + poll_wait = _RETRIEVE_POLL_MIN + else: + poll_wait = _RETRIEVE_POLL_MIN + _time.sleep(poll_wait) + continue + + # Non-retryable HTTP error (4xx other than 408/429). + _log_http("response", "GET", poll_path, status=poll_resp.status_code, body=None) + try: + poll_body = poll_resp.json() + except Exception: + poll_body = None + typed = _classify_http_error( + poll_resp.status_code, poll_body, response=poll_resp, session_id=session_id + ) + if typed is not None: + raise typed + poll_resp.raise_for_status() + + except (_ServiceRequestError, _ServiceResponseError) as exc: + # Transient network error — exponential backoff then retry. + elapsed = _time.monotonic() - poll_start + _logger.warning( + "[poller] retry on %s/%s after %s(%s) (%.0fs elapsed), backoff %.1fs", + session_id, request_id, type(exc).__name__, exc, elapsed, connection_error_backoff, + ) + error_budget.consume(type(exc).__name__) + _time.sleep(connection_error_backoff) + connection_error_backoff = min(connection_error_backoff * 2, 30.0) + continue + + normalized = _normalize_loom_result(result_data, op_type, request_id) + # Merge caller-supplied fallback fields BEFORE deserialization so that + # generated model classes receive them (post-deserialization attr + # assignment doesn't work on SDK objects with __slots__). 
+ if extra_result_fields: + for k, v in extra_result_fields.items(): + if not normalized.get(k): # server value takes precedence + normalized[k] = v + return _deserialize_model(OperationResult, normalized) + + # ── Training ────────────────────────────────────────────────────────────── + + def forward_backward( + self, + batch: List[Datum], + *, + loss_fn: Union[str, LossFn] = LossFn.CROSS_ENTROPY, + loss_fn_config: Optional[LossFnConfig] = None, + **kwargs: Any, + ) -> OperationResult: + """Submit a mini-batch for a forward + backward pass. + + If the batch exceeds ``_MAX_CHUNK_LEN`` datums or ``_MAX_CHUNK_BYTES`` + estimated payload size, the batch is automatically split into chunks. + Chunks are submitted in parallel and results are combined. + + Spec: ``fb_result = session.forward_backward(batch, loss_fn="cross_entropy")`` + + :param batch: List of :class:`~azure.ai.finetuning_sessions.models.Datum`. + :param loss_fn: Loss function name. Defaults to ``"cross_entropy"``. + :param loss_fn_config: Optional per-loss hyper-parameters. + :return: :class:`~azure.ai.finetuning_sessions.models.OperationResult`. + """ + chunks = _chunk_data(batch) + if len(chunks) <= 1: + # Single chunk — no combining needed. 
+ return self._post_and_poll( + f"/fine_tuning/sessions/{self.session_id}/forward_backward", + ForwardBackwardRequest( + forward_backward_input=ForwardBackwardInput( + data=batch, + loss_fn=loss_fn, + loss_fn_config=loss_fn_config, + ) + ), + ) + + _logger.info( + "[forward_backward] batch of %d datums split into %d chunks: %s", + len(batch), len(chunks), [len(c) for c in chunks], + ) + + def _submit_chunk(idx_chunk: tuple) -> ForwardBackwardOperationResult: + i, chunk = idx_chunk + _logger.info("[forward_backward] sending chunk %d/%d (%d datums)", i + 1, len(chunks), len(chunk)) + result = self._post_and_poll( + f"/fine_tuning/sessions/{self.session_id}/forward_backward", + ForwardBackwardRequest( + forward_backward_input=ForwardBackwardInput( + data=chunk, + loss_fn=loss_fn, + loss_fn_config=loss_fn_config, + ) + ), + ) + if isinstance(result, ForwardBackwardOperationResult): + return result + return ForwardBackwardOperationResult( + total_loss=getattr(result, "total_loss", 0.0), + per_datum_logprobs=getattr(result, "per_datum_logprobs", None), + metrics=getattr(result, "metrics", None), + ) + + # Fire all chunks in parallel. + # Wall-clock time ≈ max(chunk times) instead of sum. + with _futures.ThreadPoolExecutor(max_workers=len(chunks)) as pool: + chunk_results = list(pool.map(_submit_chunk, enumerate(chunks))) + + chunk_sizes = [len(c) for c in chunks] + return _combine_fwd_bwd_results(chunk_results, chunk_sizes) + + def optim_step( + self, + adam_params: AdamParams, + **kwargs: Any, + ) -> OperationResult: + """Apply accumulated gradients with Adam. + + Blocks until the GPU applies the weight update. 
+ + Spec: ``opt_result = session.optim_step(AdamParams(learning_rate=1e-4))`` + """ + return self._post_and_poll( + f"/fine_tuning/sessions/{self.session_id}/optim_step", + OptimStepRequest(adam_params=adam_params), + ) + + def forward( + self, + batch: List[Datum], + *, + loss_fn: Union[str, LossFn] = LossFn.CROSS_ENTROPY, + loss_fn_config: Optional[LossFnConfig] = None, + **kwargs: Any, + ) -> OperationResult: + """Submit a mini-batch for a forward-only pass (no gradient accumulation). + + Returns the same result shape as ``forward_backward`` but does not + accumulate gradients on the worker, making it safe for evaluation. + + If the batch exceeds ``_MAX_CHUNK_LEN`` datums or ``_MAX_CHUNK_BYTES`` + estimated payload size, the batch is automatically split into chunks. + Chunks are submitted in parallel and results are combined. + + :param batch: List of :class:`~azure.ai.finetuning_sessions.models.Datum`. + :param loss_fn: Loss function name. Defaults to ``"cross_entropy"``. + :param loss_fn_config: Optional per-loss hyper-parameters. + :return: :class:`~azure.ai.finetuning_sessions.models.OperationResult`. + """ + # Server expects a ForwardRequest with `forward_input` wrapping the + # shared ForwardBackwardInput payload. 
+ subpath = f"/fine_tuning/sessions/{self.session_id}/forward" + + def _build_body(chunk: List[Datum]) -> dict: + return { + "forward_input": ForwardBackwardInput( + data=chunk, + loss_fn=loss_fn, + loss_fn_config=loss_fn_config, + ) + } + + chunks = _chunk_data(batch) + if len(chunks) <= 1: + return self._post_and_poll(subpath, _build_body(batch)) + + _logger.info( + "[forward] batch of %d datums split into %d chunks: %s", + len(batch), len(chunks), [len(c) for c in chunks], + ) + + def _submit_chunk(idx_chunk: tuple) -> ForwardBackwardOperationResult: + i, chunk = idx_chunk + _logger.info("[forward] sending chunk %d/%d (%d datums)", i + 1, len(chunks), len(chunk)) + result = self._post_and_poll(subpath, _build_body(chunk)) + if isinstance(result, ForwardBackwardOperationResult): + return result + return ForwardBackwardOperationResult( + total_loss=getattr(result, "total_loss", 0.0), + per_datum_logprobs=getattr(result, "per_datum_logprobs", None), + metrics=getattr(result, "metrics", None), + ) + + with _futures.ThreadPoolExecutor(max_workers=len(chunks)) as pool: + chunk_results = list(pool.map(_submit_chunk, enumerate(chunks))) + + chunk_sizes = [len(c) for c in chunks] + return _combine_fwd_bwd_results(chunk_results, chunk_sizes) + + # ── Checkpoints ─────────────────────────────────────────────────────────── + + def save_weights( + self, + path: str, + **kwargs: Any, + ) -> OperationResult: + """Save a training checkpoint (LoRA weights + optimizer state). + + Blocks until the checkpoint is written to storage. 
+ + Spec: ``ckpt_result = session.save_weights("sft_piglatin_v1")`` + """ + return self._post_and_poll( + f"/fine_tuning/sessions/{self.session_id}/checkpoint", + SaveCheckpointRequest(path=path), + ) + + def save_weights_for_sampler( + self, + seq_id: int, + *, + sampling_session_seq_id: Optional[int] = None, + path: Optional[str] = None, + **kwargs: Any, + ) -> OperationResult: + """Push current LoRA weights to the sampler (required before calling ``sample``). + + Blocks until the sampler weights are ready. + + Spec: ``sampler_result = session.save_weights_for_sampler(seq_id=step)`` + + :param seq_id: Training step index -- must match the ``seq_id`` passed to ``sample``. + :param sampling_session_seq_id: Ordinal of this sampling session in the run. + :param path: Optional explicit checkpoint identifier. + """ + # Compute the checkpoint_id using the same formula the server uses + # (loom_sampling.py line 270). The server doesn't echo it back in the + # poll response, so we inject it before deserialization. + computed_checkpoint_id = path or f"ss{sampling_session_seq_id}_seq{seq_id}" + return self._post_and_poll( + f"/fine_tuning/sessions/{self.session_id}/checkpoint_sample", + SaveSamplerWeightsRequest( + seq_id=seq_id, + sampling_session_seq_id=sampling_session_seq_id, + path=path, + ), + extra_result_fields={"checkpoint_id": computed_checkpoint_id}, + ) + + # ── Sampling ────────────────────────────────────────────────────────────── + + def sample( + self, + prompt_tokens: List[int], + sampling_params: SamplingParams, + *, + checkpoint_id: str, + num_samples: int = 1, + sampling_session_id: Optional[str] = None, + seq_id: Optional[int] = None, + prompt_logprobs: bool = False, + topk_prompt_logprobs: int = 0, + **kwargs: Any, + ) -> OperationResult: + """Generate completions using current LoRA weights. + + Blocks until the GPU finishes sampling. 
+ + Spec: ``sample_result = session.sample(prompt_tokens, sampling_params, checkpoint_id=..., num_samples=4, ...)`` + + :param prompt_tokens: Tokenised input prompt as a list of integer IDs. + :param sampling_params: Generation parameters (max_tokens, temperature, etc.). + :param checkpoint_id: Sampler checkpoint ID returned by ``save_weights_for_sampler``. + :param num_samples: Number of independent completions to generate. Default 1. + :param sampling_session_id: ID returned by a prior ``save_weights_for_sampler`` call. + :param seq_id: Training step index; must match the one used in ``save_weights_for_sampler``. + :param prompt_logprobs: If True, return per-token log-probabilities for the prompt. + :param topk_prompt_logprobs: Top-k log-probabilities per prompt token. 0 = none. + """ + return self._post_and_poll( + f"/fine_tuning/sessions/{self.session_id}/sample", + SampleRequest( + num_samples=num_samples, + prompt=ModelInput(chunks=[ModelInputChunk(tokens=prompt_tokens)]), + sampling_params=sampling_params, + topk_prompt_logprobs=topk_prompt_logprobs, + sampling_session_id=sampling_session_id, + seq_id=seq_id, + prompt_logprobs=prompt_logprobs, + ), + extra_params={"checkpoint_id": checkpoint_id}, + ) + + # ── Session lifecycle ───────────────────────────────────────────────────── + + def heartbeat(self, **kwargs: Any) -> Any: + """Refresh an active session to prevent idle expiry.""" + return self._client.sessions.heartbeat( + session_id=self._heartbeat_session_id, + foundry_features=_PREVIEW, + api_version=_API_VERSION, + **kwargs, + ) + + def close(self, **kwargs: Any) -> None: + """Unload the session from the GPU engine. + + Stops the background heartbeat, then issues the complete request. 
def patch_sdk():
    """Do not remove from this file.

    `patch_sdk` is the last-resort customization hook for changes that cannot
    be made with the techniques described in
    https://aka.ms/azsdk/python/dpcodegen/python/customize

    This implementation intentionally does nothing.
    """
    return None
All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_utils/model_base.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_utils/model_base.py new file mode 100644 index 000000000000..db24930fdca9 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_utils/model_base.py @@ -0,0 +1,1441 @@ +# pylint: disable=line-too-long,useless-suppression,too-many-lines +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- +# pylint: disable=protected-access, broad-except + +import copy +import calendar +import decimal +import functools +import sys +import logging +import base64 +import re +import typing +import enum +import email.utils +from datetime import datetime, date, time, timedelta, timezone +from json import JSONEncoder +import xml.etree.ElementTree as ET +from collections.abc import MutableMapping +from typing_extensions import Self +import isodate +from azure.core.exceptions import DeserializationError +from azure.core import CaseInsensitiveEnumMeta +from azure.core.pipeline import PipelineResponse +from azure.core.serialization import _Null +from azure.core.rest import HttpResponse + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["SdkJSONEncoder", "Model", "rest_field", "rest_discriminator"] + +TZ_UTC = timezone.utc +_T = typing.TypeVar("_T") +_NONE_TYPE = type(None) + + +def _timedelta_as_isostr(td: timedelta) -> str: + """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 
'P4DT12H30M05S' + + Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython + + :param timedelta td: The timedelta to convert + :rtype: str + :return: ISO8601 version of this timedelta + """ + + # Split seconds to larger units + seconds = td.total_seconds() + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + + days, hours, minutes = list(map(int, (days, hours, minutes))) + seconds = round(seconds, 6) + + # Build date + date_str = "" + if days: + date_str = "%sD" % days + + if hours or minutes or seconds: + # Build time + time_str = "T" + + # Hours + bigger_exists = date_str or hours + if bigger_exists: + time_str += "{:02}H".format(hours) + + # Minutes + bigger_exists = bigger_exists or minutes + if bigger_exists: + time_str += "{:02}M".format(minutes) + + # Seconds + try: + if seconds.is_integer(): + seconds_string = "{:02}".format(int(seconds)) + else: + # 9 chars long w/ leading 0, 6 digits after decimal + seconds_string = "%09.6f" % seconds + # Remove trailing zeros + seconds_string = seconds_string.rstrip("0") + except AttributeError: # int.is_integer() raises + seconds_string = "{:02}".format(seconds) + + time_str += "{}S".format(seconds_string) + else: + time_str = "" + + return "P" + date_str + time_str + + +def _serialize_bytes(o, format: typing.Optional[str] = None) -> str: + encoded = base64.b64encode(o).decode() + if format == "base64url": + return encoded.strip("=").replace("+", "-").replace("/", "_") + return encoded + + +def _serialize_datetime(o, format: typing.Optional[str] = None): + if hasattr(o, "year") and hasattr(o, "hour"): + if format == "rfc7231": + return email.utils.format_datetime(o, usegmt=True) + if format == "unix-timestamp": + return int(calendar.timegm(o.utctimetuple())) + + # astimezone() fails for naive times in Python 2.7, so make make sure o is aware (tzinfo is set) + if not o.tzinfo: + iso_formatted = 
o.replace(tzinfo=TZ_UTC).isoformat() + else: + iso_formatted = o.astimezone(TZ_UTC).isoformat() + # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt) + return iso_formatted.replace("+00:00", "Z") + # Next try datetime.date or datetime.time + return o.isoformat() + + +def _is_readonly(p): + try: + return p._visibility == ["read"] + except AttributeError: + return False + + +class SdkJSONEncoder(JSONEncoder): + """A JSON encoder that's capable of serializing datetime objects and bytes.""" + + def __init__(self, *args, exclude_readonly: bool = False, format: typing.Optional[str] = None, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + self.format = format + + def default(self, o): # pylint: disable=too-many-return-statements + if _is_model(o): + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) + try: + return super(SdkJSONEncoder, self).default(o) + except TypeError: + if isinstance(o, _Null): + return None + if isinstance(o, decimal.Decimal): + return float(o) + if isinstance(o, (bytes, bytearray)): + return _serialize_bytes(o, self.format) + try: + # First try datetime.datetime + return _serialize_datetime(o, self.format) + except AttributeError: + pass + # Last, try datetime.timedelta + try: + return _timedelta_as_isostr(o) + except AttributeError: + # This will be raised when it hits value.total_seconds in the method above + pass + return super(SdkJSONEncoder, self).default(o) + + +_VALID_DATE = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" + r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") +_VALID_RFC7231 = re.compile( + r"(Mon|Tue|Wed|Thu|Fri|Sat|Sun),\s\d{2}\s" + r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT" +) + +_ARRAY_ENCODE_MAPPING = { + "pipeDelimited": "|", + 
"spaceDelimited": " ", + "commaDelimited": ",", + "newlineDelimited": "\n", +} + + +def _deserialize_array_encoded(delimit: str, attr): + if isinstance(attr, str): + if attr == "": + return [] + return attr.split(delimit) + return attr + + +def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime: + """Deserialize ISO-8601 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + attr = attr.upper() + match = _VALID_DATE.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + check_decimal = attr.split(".") + if len(check_decimal) > 1: + decimal_str = "" + for digit in check_decimal[1]: + if digit.isdigit(): + decimal_str += digit + else: + break + if len(decimal_str) > 6: + attr = attr.replace(decimal_str, decimal_str[0:6]) + + date_obj = isodate.parse_datetime(attr) + test_utc = date_obj.utctimetuple() + if test_utc.tm_year > 9999 or test_utc.tm_year < 1: + raise OverflowError("Hit max or min date") + return date_obj # type: ignore[no-any-return] + + +def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime: + """Deserialize RFC7231 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + match = _VALID_RFC7231.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + return email.utils.parsedate_to_datetime(attr) + + +def _deserialize_datetime_unix_timestamp(attr: typing.Union[float, datetime]) -> datetime: + """Deserialize unix timestamp into Datetime object. + + :param str attr: response string to be deserialized. 
+ :rtype: ~datetime.datetime + :returns: The datetime object from that input + """ + if isinstance(attr, datetime): + # i'm already deserialized + return attr + return datetime.fromtimestamp(attr, TZ_UTC) + + +def _deserialize_date(attr: typing.Union[str, date]) -> date: + """Deserialize ISO-8601 formatted string into Date object. + :param str attr: response string to be deserialized. + :rtype: date + :returns: The date object from that input + """ + # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. + if isinstance(attr, date): + return attr + return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore + + +def _deserialize_time(attr: typing.Union[str, time]) -> time: + """Deserialize ISO-8601 formatted string into time object. + + :param str attr: response string to be deserialized. + :rtype: datetime.time + :returns: The time object from that input + """ + if isinstance(attr, time): + return attr + return isodate.parse_time(attr) # type: ignore[no-any-return] + + +def _deserialize_bytes(attr): + if isinstance(attr, (bytes, bytearray)): + return attr + return bytes(base64.b64decode(attr)) + + +def _deserialize_bytes_base64(attr): + if isinstance(attr, (bytes, bytearray)): + return attr + padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore + attr = attr + padding # type: ignore + encoded = attr.replace("-", "+").replace("_", "/") + return bytes(base64.b64decode(encoded)) + + +def _deserialize_duration(attr): + if isinstance(attr, timedelta): + return attr + return isodate.parse_duration(attr) + + +def _deserialize_decimal(attr): + if isinstance(attr, decimal.Decimal): + return attr + return decimal.Decimal(str(attr)) + + +def _deserialize_int_as_str(attr): + if isinstance(attr, int): + return attr + return int(attr) + + +_DESERIALIZE_MAPPING = { + datetime: _deserialize_datetime, + date: _deserialize_date, + time: _deserialize_time, + bytes: _deserialize_bytes, + bytearray: _deserialize_bytes, + 
timedelta: _deserialize_duration, + typing.Any: lambda x: x, + decimal.Decimal: _deserialize_decimal, +} + +_DESERIALIZE_MAPPING_WITHFORMAT = { + "rfc3339": _deserialize_datetime, + "rfc7231": _deserialize_datetime_rfc7231, + "unix-timestamp": _deserialize_datetime_unix_timestamp, + "base64": _deserialize_bytes, + "base64url": _deserialize_bytes_base64, +} + + +def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None): + if annotation is int and rf and rf._format == "str": + return _deserialize_int_as_str + if annotation is str and rf and rf._format in _ARRAY_ENCODE_MAPPING: + return functools.partial(_deserialize_array_encoded, _ARRAY_ENCODE_MAPPING[rf._format]) + if rf and rf._format: + return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format) + return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore + + +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + +def _get_model(module_name: str, model_name: str): + models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} + module_end = module_name.rsplit(".", 1)[0] + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) + if isinstance(model_name, str): + model_name = model_name.split(".")[-1] + if model_name not in models: + return model_name + return models[model_name] + + +_UNSET = object() + + +class _MyMutableMapping(MutableMapping[str, typing.Any]): + def __init__(self, data: dict[str, typing.Any]) -> None: + self._data = data + + def __contains__(self, key: typing.Any) -> bool: + return key in self._data + + def __getitem__(self, key: str) -> typing.Any: + # If this key has been deserialized (for mutable types), we need to handle serialization + if hasattr(self, 
"_attr_to_rest_field"): + cache_attr = f"_deserialized_{key}" + if hasattr(self, cache_attr): + rf = _get_rest_field(getattr(self, "_attr_to_rest_field"), key) + if rf: + value = self._data.get(key) + if isinstance(value, (dict, list, set)): + # For mutable types, serialize and return + # But also update _data with serialized form and clear flag + # so mutations via this returned value affect _data + serialized = _serialize(value, rf._format) + # If serialized form is same type (no transformation needed), + # return _data directly so mutations work + if isinstance(serialized, type(value)) and serialized == value: + return self._data.get(key) + # Otherwise return serialized copy and clear flag + try: + object.__delattr__(self, cache_attr) + except AttributeError: + pass + # Store serialized form back + self._data[key] = serialized + return serialized + return self._data.__getitem__(key) + + def __setitem__(self, key: str, value: typing.Any) -> None: + # Clear any cached deserialized value when setting through dictionary access + cache_attr = f"_deserialized_{key}" + try: + object.__delattr__(self, cache_attr) + except AttributeError: + pass + self._data.__setitem__(key, value) + + def __delitem__(self, key: str) -> None: + self._data.__delitem__(key) + + def __iter__(self) -> typing.Iterator[typing.Any]: + return self._data.__iter__() + + def __len__(self) -> int: + return self._data.__len__() + + def __ne__(self, other: typing.Any) -> bool: + return not self.__eq__(other) + + def keys(self) -> typing.KeysView[str]: + """ + :returns: a set-like object providing a view on D's keys + :rtype: ~typing.KeysView + """ + return self._data.keys() + + def values(self) -> typing.ValuesView[typing.Any]: + """ + :returns: an object providing a view on D's values + :rtype: ~typing.ValuesView + """ + return self._data.values() + + def items(self) -> typing.ItemsView[str, typing.Any]: + """ + :returns: set-like object providing a view on D's items + :rtype: ~typing.ItemsView + """ 
+ return self._data.items() + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + """ + Get the value for key if key is in the dictionary, else default. + :param str key: The key to look up. + :param any default: The value to return if key is not in the dictionary. Defaults to None + :returns: D[k] if k in D, else d. + :rtype: any + """ + try: + return self[key] + except KeyError: + return default + + @typing.overload + def pop(self, key: str) -> typing.Any: ... # pylint: disable=arguments-differ + + @typing.overload + def pop(self, key: str, default: _T) -> _T: ... # pylint: disable=signature-differs + + @typing.overload + def pop(self, key: str, default: typing.Any) -> typing.Any: ... # pylint: disable=signature-differs + + def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any: + """ + Removes specified key and return the corresponding value. + :param str key: The key to pop. + :param any default: The value to return if key is not in the dictionary + :returns: The value corresponding to the key. + :rtype: any + :raises KeyError: If key is not found and default is not given. + """ + if default is _UNSET: + return self._data.pop(key) + return self._data.pop(key, default) + + def popitem(self) -> tuple[str, typing.Any]: + """ + Removes and returns some (key, value) pair + :returns: The (key, value) pair. + :rtype: tuple + :raises KeyError: if D is empty. + """ + return self._data.popitem() + + def clear(self) -> None: + """ + Remove all items from D. + """ + self._data.clear() + + def update(self, *args: typing.Any, **kwargs: typing.Any) -> None: # pylint: disable=arguments-differ + """ + Updates D from mapping/iterable E and F. + :param any args: Either a mapping object or an iterable of key-value pairs. + """ + self._data.update(*args, **kwargs) + + @typing.overload + def setdefault(self, key: str, default: None = None) -> None: ... + + @typing.overload + def setdefault(self, key: str, default: typing.Any) -> typing.Any: ... 
# pylint: disable=signature-differs + + def setdefault(self, key: str, default: typing.Any = _UNSET) -> typing.Any: + """ + Same as calling D.get(k, d), and setting D[k]=d if k not found + :param str key: The key to look up. + :param any default: The value to set if key is not in the dictionary + :returns: D[k] if k in D, else d. + :rtype: any + """ + if default is _UNSET: + return self._data.setdefault(key) + return self._data.setdefault(key, default) + + def __eq__(self, other: typing.Any) -> bool: + if isinstance(other, _MyMutableMapping): + return self._data == other._data + try: + other_model = self.__class__(other) + except Exception: + return False + return self._data == other_model._data + + def __repr__(self) -> str: + return str(self._data) + + +def _is_model(obj: typing.Any) -> bool: + return getattr(obj, "_is_model", False) + + +def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-many-return-statements + if isinstance(o, list): + if format in _ARRAY_ENCODE_MAPPING and all(isinstance(x, str) for x in o): + return _ARRAY_ENCODE_MAPPING[format].join(o) + return [_serialize(x, format) for x in o] + if isinstance(o, dict): + return {k: _serialize(v, format) for k, v in o.items()} + if isinstance(o, set): + return {_serialize(x, format) for x in o} + if isinstance(o, tuple): + return tuple(_serialize(x, format) for x in o) + if isinstance(o, (bytes, bytearray)): + return _serialize_bytes(o, format) + if isinstance(o, decimal.Decimal): + return float(o) + if isinstance(o, enum.Enum): + return o.value + if isinstance(o, int): + if format == "str": + return str(o) + return o + try: + # First try datetime.datetime + return _serialize_datetime(o, format) + except AttributeError: + pass + # Last, try datetime.timedelta + try: + return _timedelta_as_isostr(o) + except AttributeError: + # This will be raised when it hits value.total_seconds in the method above + pass + return o + + +def _get_rest_field(attr_to_rest_field: dict[str, 
def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typing.Any:
    """Convert a constructor-supplied value into the form stored in a model's ``_data``.

    :param rf: The rest field the value belongs to, or None for an unknown key.
    :param value: The raw value.
    :returns: The value unchanged for multipart file inputs, a deserialized value for
     model-typed fields, otherwise the serialized value.
    """
    if not rf:
        return _serialize(value, None)
    if rf._is_multipart_file_input:
        return value
    if rf._is_model:
        return _deserialize(rf._type, value)
    if isinstance(value, ET.Element):
        value = _deserialize(rf._type, value)
    return _serialize(value, rf._format)


class Model(_MyMutableMapping):
    """Base class for models whose attributes are declared as rest fields.

    Construction accepts either one positional mapping / XML element or keyword
    arguments named after the declared attributes.
    """

    _is_model = True
    # label whether current class's _attr_to_rest_field has been calculated
    # could not see _attr_to_rest_field directly because subclass inherits it from parent class
    _calculated: set[str] = set()

    def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
        class_name = self.__class__.__name__
        if len(args) > 1:
            raise TypeError(f"{class_name}.__init__() takes 2 positional arguments but {len(args) + 1} were given")
        # Start from the declared defaults; explicit input below overrides them.
        dict_to_pass = {
            rest_field._rest_name: rest_field._default
            for rest_field in self._attr_to_rest_field.values()
            if rest_field._default is not _UNSET
        }
        if args:
            if isinstance(args[0], ET.Element):
                dict_to_pass.update(self._init_from_xml(args[0]))
            else:
                dict_to_pass.update(
                    {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()}
                )
        else:
            non_attr_kwargs = [k for k in kwargs if k not in self._attr_to_rest_field]
            if non_attr_kwargs:
                # actual type errors only throw the first wrong keyword arg they see, so following that.
                raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'")
            dict_to_pass.update(
                {
                    self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v)
                    for k, v in kwargs.items()
                    if v is not None
                }
            )
        super().__init__(dict_to_pass)

    def _init_from_xml(self, element: ET.Element) -> dict[str, typing.Any]:
        """Deserialize an XML element into a dict mapping rest field names to values.

        :param ET.Element element: The XML element to deserialize from.
        :returns: A dictionary of rest_name to deserialized value pairs.
        :rtype: dict
        """
        result: dict[str, typing.Any] = {}
        model_meta = getattr(self, "_xml", {})
        existed_attr_keys: list[str] = []

        for rf in self._attr_to_rest_field.values():
            prop_meta = getattr(rf, "_xml", {})
            xml_name = prop_meta.get("name", rf._rest_name)
            xml_ns = _resolve_xml_ns(prop_meta, model_meta)
            if xml_ns:
                xml_name = "{" + xml_ns + "}" + xml_name

            # attribute
            if prop_meta.get("attribute", False) and element.get(xml_name) is not None:
                existed_attr_keys.append(xml_name)
                result[rf._rest_name] = _deserialize(rf._type, element.get(xml_name))
                continue

            # unwrapped element is array
            if prop_meta.get("unwrapped", False):
                # unwrapped array could either use prop items meta/prop meta
                _items_name = prop_meta.get("itemsName")
                if _items_name:
                    xml_name = _items_name
                    _items_ns = prop_meta.get("itemsNs")
                    if _items_ns is not None:
                        xml_ns = _items_ns
                    if xml_ns:
                        xml_name = "{" + xml_ns + "}" + xml_name
                items = element.findall(xml_name)  # pyright: ignore
                if len(items) > 0:
                    existed_attr_keys.append(xml_name)
                    result[rf._rest_name] = _deserialize(rf._type, items)
                elif not rf._is_optional:
                    existed_attr_keys.append(xml_name)
                    result[rf._rest_name] = []
                continue

            # text element is primitive type
            if prop_meta.get("text", False):
                if element.text is not None:
                    result[rf._rest_name] = _deserialize(rf._type, element.text)
                continue

            # wrapped element could be normal property or array, it should only have one element
            item = element.find(xml_name)
            if item is not None:
                existed_attr_keys.append(xml_name)
                result[rf._rest_name] = _deserialize(rf._type, item)

        # rest thing is additional properties
        for e in element:
            if e.tag not in existed_attr_keys:
                result[e.tag] = _convert_element(e)

        return result

    def copy(self) -> "Model":
        """Return a new instance of the same class built from this model's data.

        The data is passed back through the constructor, which rebuilds each value
        with ``_create_value``.

        :returns: The copied model.
        :rtype: Model
        """
        # self.__dict__ is the instance attribute dict ({"_data": ...} plus markers),
        # not the payload; the payload lives in self._data.
        return self.__class__(self._data)

    def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self:
        if f"{cls.__module__}.{cls.__qualname__}" not in cls._calculated:
            # we know the last nine classes in mro are going to be 'Model', '_MyMutableMapping', 'MutableMapping',
            # 'Mapping', 'Collection', 'Sized', 'Iterable', 'Container' and 'object'
            mros = cls.__mro__[:-9][::-1]  # ignore parents, and reverse the mro order
            attr_to_rest_field: dict[str, _RestField] = {  # map attribute name to rest_field property
                k: v for mro_class in mros for k, v in mro_class.__dict__.items() if k[0] != "_" and hasattr(v, "_type")
            }
            annotations = {
                k: v
                for mro_class in mros
                if hasattr(mro_class, "__annotations__")
                for k, v in mro_class.__annotations__.items()
            }
            for attr, rf in attr_to_rest_field.items():
                rf._module = cls.__module__
                if not rf._type:
                    rf._type = rf._get_deserialize_callable_from_annotation(annotations.get(attr, None))
                if not rf._rest_name_input:
                    rf._rest_name_input = attr
            cls._attr_to_rest_field: dict[str, _RestField] = dict(attr_to_rest_field.items())
            cls._calculated.add(f"{cls.__module__}.{cls.__qualname__}")

        return super().__new__(cls)

    def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None:
        # Register the subclass in every direct base that keeps a __mapping__ table.
        for base in cls.__bases__:
            if hasattr(base, "__mapping__"):
                base.__mapping__[discriminator or cls.__name__] = cls  # type: ignore

    @classmethod
    def _get_discriminator(cls, exist_discriminators) -> typing.Optional["_RestField"]:
        """Return the first discriminator field declared on ``cls`` itself that is not
        already listed in ``exist_discriminators``, or None."""
        for v in cls.__dict__.values():
            if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators:
                return v
        return None

    @classmethod
    def _deserialize(cls, data, exist_discriminators):
        """Build the most specific registered subclass for ``data``.

        :param data: A mapping or an XML element.
        :param list exist_discriminators: Wire names of discriminators already consumed;
         appended to in place.
        :returns: An instance of ``cls`` or of the subclass selected by the discriminator.
        """
        if not hasattr(cls, "__mapping__"):
            return cls(data)
        discriminator = cls._get_discriminator(exist_discriminators)
        if discriminator is None:
            return cls(data)
        exist_discriminators.append(discriminator._rest_name)
        if isinstance(data, ET.Element):
            model_meta = getattr(cls, "_xml", {})
            prop_meta = getattr(discriminator, "_xml", {})
            xml_name = prop_meta.get("name", discriminator._rest_name)
            xml_ns = _resolve_xml_ns(prop_meta, model_meta)
            if xml_ns:
                xml_name = "{" + xml_ns + "}" + xml_name

            # The discriminator may be an XML attribute or a child element.
            if data.get(xml_name) is not None:
                discriminator_value = data.get(xml_name)
            else:
                discriminator_value = data.find(xml_name).text  # pyright: ignore
        else:
            discriminator_value = data.get(discriminator._rest_name)
        mapped_cls = cls.__mapping__.get(discriminator_value, cls)  # pyright: ignore # pylint: disable=no-member
        return mapped_cls._deserialize(data, exist_discriminators)

    def as_dict(self, *, exclude_readonly: bool = False) -> dict[str, typing.Any]:
        """Return a dict that can be turned into json using json.dump.

        :keyword bool exclude_readonly: Whether to remove the readonly properties.
        :returns: A dict JSON compatible object
        :rtype: dict
        """

        result = {}
        readonly_props = []
        if exclude_readonly:
            readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
        for k, v in self.items():
            if exclude_readonly and k in readonly_props:  # pyright: ignore
                continue
            rf = _get_rest_field(self._attr_to_rest_field, k)
            # Multipart file inputs are passed through as is.
            is_multipart_file_input = rf._is_multipart_file_input if rf else False
            result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly)
        return result

    @staticmethod
    def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any:
        """Recursively convert one value for ``as_dict``: containers are rebuilt element by
        element and anything exposing ``as_dict`` is converted through it."""
        if v is None or isinstance(v, _Null):
            return None
        if isinstance(v, (list, tuple, set)):
            return type(v)(Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v)
        if isinstance(v, dict):
            return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()}
        return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v


def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj):
    """Return ``obj`` as is when it already is a model, otherwise deserialize it."""
    if _is_model(obj):
        return obj
    return _deserialize(model_deserializer, obj)


def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
    """Pass None through; deserialize any other value with ``if_obj_deserializer``."""
    if obj is None:
        return obj
    return _deserialize_with_callable(if_obj_deserializer, obj)


def _deserialize_with_union(deserializers, obj):
    """Try each deserializer in order and return the first result that succeeds.

    :raises DeserializationError: If every deserializer fails.
    """
    for deserializer in deserializers:
        try:
            return _deserialize(deserializer, obj)
        except DeserializationError:
            pass
    raise DeserializationError()


def _deserialize_dict(
    value_deserializer: typing.Optional[typing.Callable],
    module: typing.Optional[str],
    obj: dict[typing.Any, typing.Any],
):
    """Deserialize every value of a mapping; an XML element is first turned into a
    ``{child.tag: child}`` mapping. None passes through."""
    if obj is None:
        return obj
    if isinstance(obj, ET.Element):
        obj = {child.tag: child for child in obj}
    return {k: _deserialize(value_deserializer, v, module) for k, v in obj.items()}
_deserialize(value_deserializer, v, module) for k, v in obj.items()} + + +def _deserialize_multiple_sequence( + entry_deserializers: list[typing.Optional[typing.Callable]], + module: typing.Optional[str], + obj, +): + if obj is None: + return obj + return type(obj)(_deserialize(deserializer, entry, module) for entry, deserializer in zip(obj, entry_deserializers)) + + +def _is_array_encoded_deserializer(deserializer: functools.partial) -> bool: + return ( + isinstance(deserializer, functools.partial) + and isinstance(deserializer.args[0], functools.partial) + and deserializer.args[0].func == _deserialize_array_encoded # pylint: disable=comparison-with-callable + ) + + +def _deserialize_sequence( + deserializer: typing.Optional[typing.Callable], + module: typing.Optional[str], + obj, +): + if obj is None: + return obj + if isinstance(obj, ET.Element): + obj = list(obj) + + # encoded string may be deserialized to sequence + if isinstance(obj, str) and isinstance(deserializer, functools.partial): + # for list[str] + if _is_array_encoded_deserializer(deserializer): + return deserializer(obj) + + # for list[Union[...]] + if isinstance(deserializer.args[0], list): + for sub_deserializer in deserializer.args[0]: + if _is_array_encoded_deserializer(sub_deserializer): + return sub_deserializer(obj) + + return type(obj)(_deserialize(deserializer, entry, module) for entry in obj) + + +def _sorted_annotations(types: list[typing.Any]) -> list[typing.Any]: + return sorted( + types, + key=lambda x: hasattr(x, "__name__") and x.__name__.lower() in ("str", "float", "int", "bool"), + ) + + +def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements, too-many-branches + annotation: typing.Any, + module: typing.Optional[str], + rf: typing.Optional["_RestField"] = None, +) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]: + if not annotation: + return None + + # is it a type alias? 
+ if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) # type: ignore + + try: + if module and _is_model(annotation): + if rf: + rf._is_model = True + + return functools.partial(_deserialize_model, annotation) # pyright: ignore + except Exception: + pass + + # is it a literal? + try: + if annotation.__origin__ is typing.Literal: # pyright: ignore + return None + except AttributeError: + pass + + # is it optional? + try: + if any(a is _NONE_TYPE for a in annotation.__args__): # pyright: ignore + if rf: + rf._is_optional = True + if len(annotation.__args__) <= 2: # pyright: ignore + if_obj_deserializer = _get_deserialize_callable_from_annotation( + next(a for a in annotation.__args__ if a is not _NONE_TYPE), module, rf # pyright: ignore + ) + + return functools.partial(_deserialize_with_optional, if_obj_deserializer) + # the type is Optional[Union[...]], we need to remove the None type from the Union + annotation_copy = copy.copy(annotation) + annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a is not _NONE_TYPE] # pyright: ignore + return _get_deserialize_callable_from_annotation(annotation_copy, module, rf) + except AttributeError: + pass + + # is it union? 
+ if getattr(annotation, "__origin__", None) is typing.Union: + # initial ordering is we make `string` the last deserialization option, because it is often them most generic + deserializers = [ + _get_deserialize_callable_from_annotation(arg, module, rf) + for arg in _sorted_annotations(annotation.__args__) # pyright: ignore + ] + + return functools.partial(_deserialize_with_union, deserializers) + + try: + annotation_name = ( + annotation.__name__ if hasattr(annotation, "__name__") else annotation._name # pyright: ignore + ) + if annotation_name.lower() == "dict": + value_deserializer = _get_deserialize_callable_from_annotation( + annotation.__args__[1], module, rf # pyright: ignore + ) + + return functools.partial( + _deserialize_dict, + value_deserializer, + module, + ) + except (AttributeError, IndexError): + pass + try: + annotation_name = ( + annotation.__name__ if hasattr(annotation, "__name__") else annotation._name # pyright: ignore + ) + if annotation_name.lower() in ["list", "set", "tuple", "sequence"]: + if len(annotation.__args__) > 1: # pyright: ignore + entry_deserializers = [ + _get_deserialize_callable_from_annotation(dt, module, rf) + for dt in annotation.__args__ # pyright: ignore + ] + return functools.partial(_deserialize_multiple_sequence, entry_deserializers, module) + deserializer = _get_deserialize_callable_from_annotation( + annotation.__args__[0], module, rf # pyright: ignore + ) + + return functools.partial(_deserialize_sequence, deserializer, module) + except (TypeError, IndexError, AttributeError, SyntaxError): + pass + + def _deserialize_default( + deserializer, + obj, + ): + if obj is None: + return obj + try: + return _deserialize_with_callable(deserializer, obj) + except Exception: + pass + return obj + + if get_deserializer(annotation, rf): + return functools.partial(_deserialize_default, get_deserializer(annotation, rf)) + + return functools.partial(_deserialize_default, annotation) + + +def _deserialize_with_callable( + 
deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], + value: typing.Any, +): # pylint: disable=too-many-return-statements + try: + if value is None or isinstance(value, _Null): + return None + if isinstance(value, ET.Element): + if deserializer is str: + return value.text or "" + if deserializer is int: + return int(value.text) if value.text else None + if deserializer is float: + return float(value.text) if value.text else None + if deserializer is bool: + return value.text == "true" if value.text else None + if deserializer and deserializer in _DESERIALIZE_MAPPING.values(): + return deserializer(value.text) if value.text else None + if deserializer and deserializer in _DESERIALIZE_MAPPING_WITHFORMAT.values(): + return deserializer(value.text) if value.text else None + if deserializer is None: + return value + if deserializer in [int, float, bool]: + return deserializer(value) + if isinstance(deserializer, CaseInsensitiveEnumMeta): + try: + return deserializer(value.text if isinstance(value, ET.Element) else value) + except ValueError: + # for unknown value, return raw value + return value.text if isinstance(value, ET.Element) else value + if isinstance(deserializer, type) and issubclass(deserializer, Model): + return deserializer._deserialize(value, []) + return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) + except Exception as e: + raise DeserializationError() from e + + +def _deserialize( + deserializer: typing.Any, + value: typing.Any, + module: typing.Optional[str] = None, + rf: typing.Optional["_RestField"] = None, + format: typing.Optional[str] = None, +) -> typing.Any: + if isinstance(value, PipelineResponse): + value = value.http_response.json() + if rf is None and format: + rf = _RestField(format=format) + if not isinstance(deserializer, functools.partial): + deserializer = _get_deserialize_callable_from_annotation(deserializer, module, rf) + return _deserialize_with_callable(deserializer, value) + + 
def _failsafe_deserialize(
    deserializer: typing.Any,
    response: HttpResponse,
    module: typing.Optional[str] = None,
    rf: typing.Optional["_RestField"] = None,
    format: typing.Optional[str] = None,
) -> typing.Any:
    """Deserialize a response's JSON body without ever raising.

    :param deserializer: Annotation or callable describing the target type.
    :param response: The HTTP response whose ``json()`` body is deserialized.
    :param module: Module name used to resolve string annotations.
    :param rf: Rest field providing format information.
    :param format: Format name forwarded to ``_deserialize``.
    :returns: The deserialized body, or None when anything fails (a warning is logged).
    """
    try:
        payload = response.json()
        return _deserialize(deserializer, payload, module, rf, format)
    except Exception:  # pylint: disable=broad-except
        _LOGGER.warning(
            "Ran into a deserialization error. Ignoring since this is failsafe deserialization", exc_info=True
        )
        return None


def _failsafe_deserialize_xml(
    deserializer: typing.Any,
    response: HttpResponse,
) -> typing.Any:
    """Deserialize a response's XML text without ever raising.

    :param deserializer: Annotation or callable describing the target type.
    :param response: The HTTP response whose ``text()`` body is deserialized.
    :returns: The deserialized body, or None when anything fails (a warning is logged).
    """
    try:
        body = response.text()
        return _deserialize_xml(deserializer, body)
    except Exception:  # pylint: disable=broad-except
        _LOGGER.warning(
            "Ran into a deserialization error. Ignoring since this is failsafe deserialization", exc_info=True
        )
        return None


# pylint: disable=too-many-instance-attributes
class _RestField:
    """Descriptor describing one model property and how it maps onto the wire payload.

    Reading the attribute deserializes the stored value; writing it serializes the new
    value into the owning model's ``_data``.
    """

    def __init__(
        self,
        *,
        name: typing.Optional[str] = None,
        type: typing.Optional[typing.Callable] = None,  # pylint: disable=redefined-builtin
        is_discriminator: bool = False,
        visibility: typing.Optional[list[str]] = None,
        default: typing.Any = _UNSET,
        format: typing.Optional[str] = None,
        is_multipart_file_input: bool = False,
        xml: typing.Optional[dict[str, typing.Any]] = None,
    ):
        # Declared configuration.
        self._rest_name_input = name
        self._type = type
        self._is_discriminator = is_discriminator
        self._visibility = visibility
        self._default = default
        self._format = format
        self._is_multipart_file_input = is_multipart_file_input
        self._xml = {} if xml is None else xml
        # State that starts empty and is filled in after construction.
        self._module: typing.Optional[str] = None
        self._is_model = False
        self._is_optional = False

    @property
    def _class_type(self) -> typing.Any:
        """First bound argument of ``_type``, unwrapping one extra partial layer."""
        outer = getattr(self._type, "args", [None])[0]
        # type may be wrapped by nested functools.partial so we need to check for that
        if not isinstance(outer, functools.partial):
            return outer
        return getattr(outer, "args", [None])[0]

    @property
    def _rest_name(self) -> str:
        """Wire name of the property.

        :raises ValueError: If no name has been set yet.
        """
        wire_name = self._rest_name_input
        if wire_name is None:
            raise ValueError("Rest name was never set")
        return wire_name

    def __get__(self, obj: Model, type=None):  # pylint: disable=redefined-builtin
        # by this point, type and rest_name will have a value bc we default
        # them in __new__ of the Model class
        wire_name = self._rest_name
        # Read _data directly: going through obj[...] would clear the cache marker.
        item = obj._data.get(wire_name)
        if item is None or self._is_model:
            return item

        # A marker means _data already holds the deserialized object; hand it out so
        # that mutations made by the caller land in _data.
        cache_attr = f"_deserialized_{wire_name}"
        if hasattr(obj, cache_attr):
            return obj._data.get(wire_name)

        deserialized = _deserialize(self._type, _serialize(item, self._format), rf=self)

        # Mutable results replace the stored value and are marked as deserialized.
        if isinstance(deserialized, (dict, list, set)):
            obj._data[wire_name] = deserialized
            object.__setattr__(obj, cache_attr, True)
        return deserialized

    def __set__(self, obj: Model, value) -> None:
        wire_name = self._rest_name
        # A new value invalidates any cached deserialized object.
        cache_attr = f"_deserialized_{wire_name}"
        if hasattr(obj, cache_attr):
            object.__delattr__(obj, cache_attr)

        if value is None:
            # we want to wipe out entries if users set attr to None
            try:
                obj.__delitem__(wire_name)
            except KeyError:
                pass
            return

        if not self._is_model:
            obj.__setitem__(wire_name, _serialize(value, self._format))
            return

        # Model-typed fields store model instances, converting plain input first.
        if not _is_model(value):
            value = _deserialize(self._type, value)
        obj.__setitem__(wire_name, value)

    def _get_deserialize_callable_from_annotation(
        self, annotation: typing.Any
    ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
        """Resolve ``annotation`` with this field's module and the field itself as context."""
        return _get_deserialize_callable_from_annotation(annotation, self._module, self)
def rest_field(
    *,
    name: typing.Optional[str] = None,
    type: typing.Optional[typing.Callable] = None,  # pylint: disable=redefined-builtin
    visibility: typing.Optional[list[str]] = None,
    default: typing.Any = _UNSET,
    format: typing.Optional[str] = None,
    is_multipart_file_input: bool = False,
    xml: typing.Optional[dict[str, typing.Any]] = None,
) -> typing.Any:
    """Declare a regular (non-discriminator) model property.

    :keyword name: Wire name of the property.
    :keyword type: Explicit deserializer for the property.
    :keyword visibility: Visibility labels of the property.
    :keyword default: Default value; ``_UNSET`` means no default.
    :keyword format: Serialization format name.
    :keyword is_multipart_file_input: Whether the value is a multipart file input.
    :keyword xml: XML metadata for the property.
    :returns: The ``_RestField`` descriptor.
    """
    options: dict[str, typing.Any] = {
        "name": name,
        "type": type,
        "visibility": visibility,
        "default": default,
        "format": format,
        "is_multipart_file_input": is_multipart_file_input,
        "xml": xml,
    }
    return _RestField(**options)


def rest_discriminator(
    *,
    name: typing.Optional[str] = None,
    type: typing.Optional[typing.Callable] = None,  # pylint: disable=redefined-builtin
    visibility: typing.Optional[list[str]] = None,
    xml: typing.Optional[dict[str, typing.Any]] = None,
) -> typing.Any:
    """Declare a discriminator model property.

    :keyword name: Wire name of the property.
    :keyword type: Explicit deserializer for the property.
    :keyword visibility: Visibility labels of the property.
    :keyword xml: XML metadata for the property.
    :returns: The ``_RestField`` descriptor, flagged as a discriminator.
    """
    options: dict[str, typing.Any] = {"name": name, "type": type, "visibility": visibility, "xml": xml}
    return _RestField(is_discriminator=True, **options)


def serialize_xml(model: Model, exclude_readonly: bool = False) -> str:
    """Serialize a model to XML.

    :param Model model: The model to serialize.
    :param bool exclude_readonly: Whether to exclude readonly properties.
    :returns: The XML representation of the model.
    :rtype: str
    """
    root = _get_element(model, exclude_readonly)
    return ET.tostring(root, encoding="unicode")  # type: ignore
+ :rtype: str or None + """ + ns = meta.get("ns") + if ns is None: + ns = meta.get("namespace") + return ns + + +def _resolve_xml_ns( + prop_meta: dict[str, typing.Any], model_meta: typing.Optional[dict[str, typing.Any]] = None +) -> typing.Optional[str]: + """Resolve XML namespace for a property, falling back to model namespace when appropriate. + + Checks the property metadata first; if no namespace is found and the model does not declare + an explicit prefix, falls back to the model-level namespace. + + :param dict prop_meta: The property metadata dictionary. + :param dict model_meta: The model metadata dictionary, used as fallback. + :returns: The resolved namespace string, or None. + :rtype: str or None + """ + ns = _get_xml_ns(prop_meta) + if ns is None and model_meta is not None and not model_meta.get("prefix"): + ns = _get_xml_ns(model_meta) + return ns + + +def _set_xml_attribute(element: ET.Element, name: str, value: typing.Any, prop_meta: dict[str, typing.Any]) -> None: + """Set an XML attribute on an element, handling namespace prefix registration. + + :param ET.Element element: The element to set the attribute on. + :param str name: The default attribute name (wire name). + :param any value: The attribute value. + :param dict prop_meta: The property metadata dictionary. 
+ """ + xml_name = prop_meta.get("name", name) + _attr_ns = _get_xml_ns(prop_meta) + if _attr_ns: + _attr_prefix = prop_meta.get("prefix") + if _attr_prefix: + _safe_register_namespace(_attr_prefix, _attr_ns) + xml_name = "{" + _attr_ns + "}" + xml_name + element.set(xml_name, _get_primitive_type_value(value)) + + +def _get_element( + o: typing.Any, + exclude_readonly: bool = False, + parent_meta: typing.Optional[dict[str, typing.Any]] = None, + wrapped_element: typing.Optional[ET.Element] = None, +) -> typing.Union[ET.Element, list[ET.Element]]: + if _is_model(o): + model_meta = getattr(o, "_xml", {}) + + # if prop is a model, then use the prop element directly, else generate a wrapper of model + if wrapped_element is None: + # When serializing as an array item (parent_meta is set), check if the parent has an + # explicit itemsName. This ensures correct element names for unwrapped arrays (where + # the element tag is the property/items name, not the model type name). + _items_name = parent_meta.get("itemsName") if parent_meta is not None else None + element_name = _items_name if _items_name else (model_meta.get("name") or o.__class__.__name__) + _model_ns = _get_xml_ns(model_meta) + wrapped_element = _create_xml_element( + element_name, + model_meta.get("prefix"), + _model_ns, + ) + + readonly_props = [] + if exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + + for k, v in o.items(): + # do not serialize readonly properties + if exclude_readonly and k in readonly_props: + continue + + prop_rest_field = _get_rest_field(o._attr_to_rest_field, k) + if prop_rest_field: + prop_meta = getattr(prop_rest_field, "_xml").copy() + # use the wire name as xml name if no specific name is set + if prop_meta.get("name") is None: + prop_meta["name"] = k + else: + # additional properties will not have rest field, use the wire name as xml name + prop_meta = {"name": k} + + # Propagate model namespace to properties only 
for old-style "ns"-keyed models. + # DPG-generated models use the "namespace" key and explicitly declare namespace on + # each property that needs it, so propagation is intentionally skipped for them. + if prop_meta.get("ns") is None and model_meta.get("ns"): + prop_meta["ns"] = model_meta.get("ns") + prop_meta["prefix"] = model_meta.get("prefix") + + if prop_meta.get("unwrapped", False): + # unwrapped could only set on array + wrapped_element.extend(_get_element(v, exclude_readonly, prop_meta)) + elif prop_meta.get("text", False): + # text could only set on primitive type + wrapped_element.text = _get_primitive_type_value(v) + elif prop_meta.get("attribute", False): + _set_xml_attribute(wrapped_element, k, v, prop_meta) + else: + # other wrapped prop element + wrapped_element.append(_get_wrapped_element(v, exclude_readonly, prop_meta)) + return wrapped_element + if isinstance(o, list): + return [_get_element(x, exclude_readonly, parent_meta) for x in o] # type: ignore + if isinstance(o, dict): + result = [] + _dict_ns = _get_xml_ns(parent_meta) if parent_meta else None + for k, v in o.items(): + result.append( + _get_wrapped_element( + v, + exclude_readonly, + { + "name": k, + "ns": _dict_ns, + "prefix": parent_meta.get("prefix") if parent_meta else None, + }, + ) + ) + return result + + # primitive case need to create element based on parent_meta + if parent_meta: + _items_ns = parent_meta.get("itemsNs") + if _items_ns is None: + _items_ns = _get_xml_ns(parent_meta) + return _get_wrapped_element( + o, + exclude_readonly, + { + "name": parent_meta.get("itemsName", parent_meta.get("name")), + "prefix": parent_meta.get("itemsPrefix", parent_meta.get("prefix")), + "ns": _items_ns, + }, + ) + + raise ValueError("Could not serialize value into xml: " + o) + + +def _get_wrapped_element( + v: typing.Any, + exclude_readonly: bool, + meta: typing.Optional[dict[str, typing.Any]], +) -> ET.Element: + _meta_ns = _get_xml_ns(meta) if meta else None + wrapped_element = 
_create_xml_element( + meta.get("name") if meta else None, meta.get("prefix") if meta else None, _meta_ns + ) + if isinstance(v, (dict, list)): + wrapped_element.extend(_get_element(v, exclude_readonly, meta)) + elif _is_model(v): + _get_element(v, exclude_readonly, meta, wrapped_element) + else: + wrapped_element.text = _get_primitive_type_value(v) + return wrapped_element # type: ignore[no-any-return] + + +def _get_primitive_type_value(v) -> str: + if v is True: + return "true" + if v is False: + return "false" + if isinstance(v, _Null): + return "" + return str(v) + + +def _safe_register_namespace(prefix: str, ns: str) -> None: + """Register an XML namespace prefix, handling reserved prefix patterns. + + Some prefixes (e.g. 'ns2') match Python's reserved 'ns\\d+' pattern used for + auto-generated prefixes, causing register_namespace to raise ValueError. + Falls back to directly registering in the internal namespace map. + + :param str prefix: The namespace prefix to register. + :param str ns: The namespace URI. 
+ """ + try: + ET.register_namespace(prefix, ns) + except ValueError: + _ns_map = getattr(ET, "_namespace_map", None) + if _ns_map is not None: + _ns_map[ns] = prefix + + +def _create_xml_element( + tag: typing.Any, prefix: typing.Optional[str] = None, ns: typing.Optional[str] = None +) -> ET.Element: + if prefix and ns: + _safe_register_namespace(prefix, ns) + if ns: + return ET.Element("{" + ns + "}" + tag) + return ET.Element(tag) + + +def _deserialize_xml( + deserializer: typing.Any, + value: str, +) -> typing.Any: + element = ET.fromstring(value) # nosec + return _deserialize(deserializer, element) + + +def _convert_element(e: ET.Element): + # dict case + if len(e.attrib) > 0 or len({child.tag for child in e}) > 1: + dict_result: dict[str, typing.Any] = {} + for child in e: + if dict_result.get(child.tag) is not None: + if isinstance(dict_result[child.tag], list): + dict_result[child.tag].append(_convert_element(child)) + else: + dict_result[child.tag] = [dict_result[child.tag], _convert_element(child)] + else: + dict_result[child.tag] = _convert_element(child) + dict_result.update(e.attrib) + return dict_result + # array case + if len(e) > 0: + array_result: list[typing.Any] = [] + for child in e: + array_result.append(_convert_element(child)) + return array_result + # primitive case + return e.text diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_utils/serialization.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_utils/serialization.py new file mode 100644 index 000000000000..81ec1de5922b --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/_utils/serialization.py @@ -0,0 +1,2041 @@ +# pylint: disable=line-too-long,useless-suppression,too-many-lines +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +# pyright: reportUnnecessaryTypeIgnoreComment=false + +from base64 import b64decode, b64encode +import calendar +import datetime +import decimal +import email +from enum import Enum +import json +import logging +import re +import sys +import codecs +from typing import ( + Any, + cast, + Optional, + Union, + AnyStr, + IO, + Mapping, + Callable, + MutableMapping, +) + +try: + from urllib import quote # type: ignore +except ImportError: + from urllib.parse import quote +import xml.etree.ElementTree as ET + +import isodate # type: ignore +from typing_extensions import Self + +from azure.core.exceptions import DeserializationError, SerializationError +from azure.core.serialization import NULL as CoreNull + +_BOM = codecs.BOM_UTF8.decode(encoding="utf-8") + +JSON = MutableMapping[str, Any] + + +class RawDeserializer: + + # Accept "text" because we're open minded people... + JSON_REGEXP = re.compile(r"^(application|text)/([a-z+.]+\+)?json$") + + # Name used in context + CONTEXT_NAME = "deserialized_data" + + @classmethod + def deserialize_from_text(cls, data: Optional[Union[AnyStr, IO]], content_type: Optional[str] = None) -> Any: + """Decode data according to content-type. + + Accept a stream of data as well, but will be load at once in memory for now. + + If no content-type, will return the string version (not bytes, not stream) + + :param data: Input, could be bytes or stream (will be decoded with UTF8) or text + :type data: str or bytes or IO + :param str content_type: The content type. + :return: The deserialized data. 
+ :rtype: object + """ + if hasattr(data, "read"): + # Assume a stream + data = cast(IO, data).read() + + if isinstance(data, bytes): + data_as_str = data.decode(encoding="utf-8-sig") + else: + # Explain to mypy the correct type. + data_as_str = cast(str, data) + + # Remove Byte Order Mark if present in string + data_as_str = data_as_str.lstrip(_BOM) + + if content_type is None: + return data + + if cls.JSON_REGEXP.match(content_type): + try: + return json.loads(data_as_str) + except ValueError as err: + raise DeserializationError("JSON is invalid: {}".format(err), err) from err + elif "xml" in (content_type or []): + try: + + try: + if isinstance(data, unicode): # type: ignore + # If I'm Python 2.7 and unicode XML will scream if I try a "fromstring" on unicode string + data_as_str = data_as_str.encode(encoding="utf-8") # type: ignore + except NameError: + pass + + return ET.fromstring(data_as_str) # nosec + except ET.ParseError as err: + # It might be because the server has an issue, and returned JSON with + # content-type XML.... + # So let's try a JSON load, and if it's still broken + # let's flow the initial exception + def _json_attemp(data): + try: + return True, json.loads(data) + except ValueError: + return False, None # Don't care about this one + + success, json_result = _json_attemp(data) + if success: + return json_result + # If i'm here, it's not JSON, it's not XML, let's scream + # and raise the last context in this block (the XML exception) + # The function hack is because Py2.7 messes up with exception + # context otherwise. + _LOGGER.critical("Wasn't XML not JSON, failing") + raise DeserializationError("XML is invalid") from err + elif content_type.startswith("text/"): + return data_as_str + raise DeserializationError("Cannot deserialize content-type: {}".format(content_type)) + + @classmethod + def deserialize_from_http_generics(cls, body_bytes: Optional[Union[AnyStr, IO]], headers: Mapping) -> Any: + """Deserialize from HTTP response. 
+ + Use bytes and headers to NOT use any requests/aiohttp or whatever + specific implementation. + Headers will tested for "content-type" + + :param bytes body_bytes: The body of the response. + :param dict headers: The headers of the response. + :returns: The deserialized data. + :rtype: object + """ + # Try to use content-type from headers if available + content_type = None + if "content-type" in headers: + content_type = headers["content-type"].split(";")[0].strip().lower() + # Ouch, this server did not declare what it sent... + # Let's guess it's JSON... + # Also, since Autorest was considering that an empty body was a valid JSON, + # need that test as well.... + else: + content_type = "application/json" + + if body_bytes: + return cls.deserialize_from_text(body_bytes, content_type) + return None + + +_LOGGER = logging.getLogger(__name__) + +try: + _long_type = long # type: ignore +except NameError: + _long_type = int + +TZ_UTC = datetime.timezone.utc + +_FLATTEN = re.compile(r"(? None: + self.additional_properties: Optional[dict[str, Any]] = {} + for k in kwargs: # pylint: disable=consider-using-dict-items + if k not in self._attribute_map: + _LOGGER.warning("%s is not a known attribute of class %s and will be ignored", k, self.__class__) + elif k in self._validation and self._validation[k].get("readonly", False): + _LOGGER.warning("Readonly attribute %s will be ignored in class %s", k, self.__class__) + else: + setattr(self, k, kwargs[k]) + + def __eq__(self, other: Any) -> bool: + """Compare objects by comparing all attributes. + + :param object other: The object to compare + :returns: True if objects are equal + :rtype: bool + """ + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other: Any) -> bool: + """Compare objects by comparing all attributes. 
+ + :param object other: The object to compare + :returns: True if objects are not equal + :rtype: bool + """ + return not self.__eq__(other) + + def __str__(self) -> str: + return str(self.__dict__) + + @classmethod + def enable_additional_properties_sending(cls) -> None: + cls._attribute_map["additional_properties"] = {"key": "", "type": "{object}"} + + @classmethod + def is_xml_model(cls) -> bool: + try: + cls._xml_map # type: ignore + except AttributeError: + return False + return True + + @classmethod + def _create_xml_node(cls): + """Create XML node. + + :returns: The XML node + :rtype: xml.etree.ElementTree.Element + """ + try: + xml_map = cls._xml_map # type: ignore + except AttributeError: + xml_map = {} + + return _create_xml_node(xml_map.get("name", cls.__name__), xml_map.get("prefix", None), xml_map.get("ns", None)) + + def serialize(self, keep_readonly: bool = False, **kwargs: Any) -> JSON: + """Return the JSON that would be sent to server from this model. + + This is an alias to `as_dict(full_restapi_key_transformer, keep_readonly=False)`. + + If you want XML serialization, you can pass the kwargs is_xml=True. + + :param bool keep_readonly: If you want to serialize the readonly attributes + :returns: A dict JSON compatible object + :rtype: dict + """ + serializer = Serializer(self._infer_class_models()) + return serializer._serialize( # type: ignore # pylint: disable=protected-access + self, keep_readonly=keep_readonly, **kwargs + ) + + def as_dict( + self, + keep_readonly: bool = True, + key_transformer: Callable[[str, dict[str, Any], Any], Any] = attribute_transformer, + **kwargs: Any + ) -> JSON: + """Return a dict that can be serialized using json.dump. + + Advanced usage might optionally use a callback as parameter: + + .. code::python + + def my_key_transformer(key, attr_desc, value): + return key + + Key is the attribute name used in Python. Attr_desc + is a dict of metadata. 
Currently contains 'type' with the + msrest type and 'key' with the RestAPI encoded key. + Value is the current value in this object. + + The string returned will be used to serialize the key. + If the return type is a list, this is considered hierarchical + result dict. + + See the three examples in this file: + + - attribute_transformer + - full_restapi_key_transformer + - last_restapi_key_transformer + + If you want XML serialization, you can pass the kwargs is_xml=True. + + :param bool keep_readonly: If you want to serialize the readonly attributes + :param function key_transformer: A key transformer function. + :returns: A dict JSON compatible object + :rtype: dict + """ + serializer = Serializer(self._infer_class_models()) + return serializer._serialize( # type: ignore # pylint: disable=protected-access + self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs + ) + + @classmethod + def _infer_class_models(cls): + try: + str_models = cls.__module__.rsplit(".", 1)[0] + models = sys.modules[str_models] + client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)} + if cls.__name__ not in client_models: + raise ValueError("Not Autorest generated code") + except Exception: # pylint: disable=broad-exception-caught + # Assume it's not Autorest generated (tests?). Add ourselves as dependencies. + client_models = {cls.__name__: cls} + return client_models + + @classmethod + def deserialize(cls, data: Any, content_type: Optional[str] = None) -> Self: + """Parse a str using the RestAPI syntax and return a model. + + :param str data: A str using RestAPI structure. JSON by default. + :param str content_type: JSON by default, set application/xml if XML. 
+ :returns: An instance of this model + :raises DeserializationError: if something went wrong + :rtype: Self + """ + deserializer = Deserializer(cls._infer_class_models()) + return deserializer(cls.__name__, data, content_type=content_type) # type: ignore + + @classmethod + def from_dict( + cls, + data: Any, + key_extractors: Optional[Callable[[str, dict[str, Any], Any], Any]] = None, + content_type: Optional[str] = None, + ) -> Self: + """Parse a dict using given key extractor return a model. + + By default consider key + extractors (rest_key_case_insensitive_extractor, attribute_key_case_insensitive_extractor + and last_rest_key_case_insensitive_extractor) + + :param dict data: A dict using RestAPI structure + :param function key_extractors: A key extractor function. + :param str content_type: JSON by default, set application/xml if XML. + :returns: An instance of this model + :raises DeserializationError: if something went wrong + :rtype: Self + """ + deserializer = Deserializer(cls._infer_class_models()) + deserializer.key_extractors = ( # type: ignore + [ # type: ignore + attribute_key_case_insensitive_extractor, + rest_key_case_insensitive_extractor, + last_rest_key_case_insensitive_extractor, + ] + if key_extractors is None + else key_extractors + ) + return deserializer(cls.__name__, data, content_type=content_type) # type: ignore + + @classmethod + def _flatten_subtype(cls, key, objects): + if "_subtype_map" not in cls.__dict__: + return {} + result = dict(cls._subtype_map[key]) + for valuetype in cls._subtype_map[key].values(): + result |= objects[valuetype]._flatten_subtype(key, objects) # pylint: disable=protected-access + return result + + @classmethod + def _classify(cls, response, objects): + """Check the class _subtype_map for any child classes. + We want to ignore any inherited _subtype_maps. 
+ + :param dict response: The initial data + :param dict objects: The class objects + :returns: The class to be used + :rtype: class + """ + for subtype_key in cls.__dict__.get("_subtype_map", {}).keys(): + subtype_value = None + + if not isinstance(response, ET.Element): + rest_api_response_key = cls._get_rest_key_parts(subtype_key)[-1] + subtype_value = response.get(rest_api_response_key, None) or response.get(subtype_key, None) + else: + subtype_value = xml_key_extractor(subtype_key, cls._attribute_map[subtype_key], response) + if subtype_value: + # Try to match base class. Can be class name only + # (bug to fix in Autorest to support x-ms-discriminator-name) + if cls.__name__ == subtype_value: + return cls + flatten_mapping_type = cls._flatten_subtype(subtype_key, objects) + try: + return objects[flatten_mapping_type[subtype_value]] # type: ignore + except KeyError: + _LOGGER.warning( + "Subtype value %s has no mapping, use base class %s.", + subtype_value, + cls.__name__, + ) + break + else: + _LOGGER.warning("Discriminator %s is absent or null, use base class %s.", subtype_key, cls.__name__) + break + return cls + + @classmethod + def _get_rest_key_parts(cls, attr_key): + """Get the RestAPI key of this attr, split it and decode part + :param str attr_key: Attribute key must be in attribute_map. + :returns: A list of RestAPI part + :rtype: list + """ + rest_split_key = _FLATTEN.split(cls._attribute_map[attr_key]["key"]) + return [_decode_attribute_map_key(key_part) for key_part in rest_split_key] + + +def _decode_attribute_map_key(key): + """This decode a key in an _attribute_map to the actual key we want to look at + inside the received data. 
+ + :param str key: A key string from the generated code + :returns: The decoded key + :rtype: str + """ + return key.replace("\\.", ".") + + +class Serializer: # pylint: disable=too-many-public-methods + """Request object model serializer.""" + + basic_types = {str: "str", int: "int", bool: "bool", float: "float"} + + _xml_basic_types_serializers = {"bool": lambda x: str(x).lower()} + days = {0: "Mon", 1: "Tue", 2: "Wed", 3: "Thu", 4: "Fri", 5: "Sat", 6: "Sun"} + months = { + 1: "Jan", + 2: "Feb", + 3: "Mar", + 4: "Apr", + 5: "May", + 6: "Jun", + 7: "Jul", + 8: "Aug", + 9: "Sep", + 10: "Oct", + 11: "Nov", + 12: "Dec", + } + validation = { + "min_length": lambda x, y: len(x) < y, + "max_length": lambda x, y: len(x) > y, + "minimum": lambda x, y: x < y, + "maximum": lambda x, y: x > y, + "minimum_ex": lambda x, y: x <= y, + "maximum_ex": lambda x, y: x >= y, + "min_items": lambda x, y: len(x) < y, + "max_items": lambda x, y: len(x) > y, + "pattern": lambda x, y: not re.match(y, x, re.UNICODE), + "unique": lambda x, y: len(x) != len(set(x)), + "multiple": lambda x, y: x % y != 0, + } + + def __init__(self, classes: Optional[Mapping[str, type]] = None) -> None: + self.serialize_type = { + "iso-8601": Serializer.serialize_iso, + "rfc-1123": Serializer.serialize_rfc, + "unix-time": Serializer.serialize_unix, + "duration": Serializer.serialize_duration, + "date": Serializer.serialize_date, + "time": Serializer.serialize_time, + "decimal": Serializer.serialize_decimal, + "long": Serializer.serialize_long, + "bytearray": Serializer.serialize_bytearray, + "base64": Serializer.serialize_base64, + "object": self.serialize_object, + "[]": self.serialize_iter, + "{}": self.serialize_dict, + } + self.dependencies: dict[str, type] = dict(classes) if classes else {} + self.key_transformer = full_restapi_key_transformer + self.client_side_validation = True + + def _serialize( # pylint: disable=too-many-nested-blocks, too-many-branches, too-many-statements, too-many-locals + self, 
target_obj, data_type=None, **kwargs + ): + """Serialize data into a string according to type. + + :param object target_obj: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str, dict + :raises SerializationError: if serialization fails. + :returns: The serialized data. + """ + key_transformer = kwargs.get("key_transformer", self.key_transformer) + keep_readonly = kwargs.get("keep_readonly", False) + if target_obj is None: + return None + + attr_name = None + class_name = target_obj.__class__.__name__ + + if data_type: + return self.serialize_data(target_obj, data_type, **kwargs) + + if not hasattr(target_obj, "_attribute_map"): + data_type = type(target_obj).__name__ + if data_type in self.basic_types.values(): + return self.serialize_data(target_obj, data_type, **kwargs) + + # Force "is_xml" kwargs if we detect a XML model + try: + is_xml_model_serialization = kwargs["is_xml"] + except KeyError: + is_xml_model_serialization = kwargs.setdefault("is_xml", target_obj.is_xml_model()) + + serialized = {} + if is_xml_model_serialization: + serialized = target_obj._create_xml_node() # pylint: disable=protected-access + try: + attributes = target_obj._attribute_map # pylint: disable=protected-access + for attr, attr_desc in attributes.items(): + attr_name = attr + if not keep_readonly and target_obj._validation.get( # pylint: disable=protected-access + attr_name, {} + ).get("readonly", False): + continue + + if attr_name == "additional_properties" and attr_desc["key"] == "": + if target_obj.additional_properties is not None: + serialized |= target_obj.additional_properties + continue + try: + + orig_attr = getattr(target_obj, attr) + if is_xml_model_serialization: + pass # Don't provide "transformer" for XML for now. 
Keep "orig_attr" + else: # JSON + keys, orig_attr = key_transformer(attr, attr_desc.copy(), orig_attr) + keys = keys if isinstance(keys, list) else [keys] + + kwargs["serialization_ctxt"] = attr_desc + new_attr = self.serialize_data(orig_attr, attr_desc["type"], **kwargs) + + if is_xml_model_serialization: + xml_desc = attr_desc.get("xml", {}) + xml_name = xml_desc.get("name", attr_desc["key"]) + xml_prefix = xml_desc.get("prefix", None) + xml_ns = xml_desc.get("ns", None) + if xml_desc.get("attr", False): + if xml_ns: + ET.register_namespace(xml_prefix, xml_ns) + xml_name = "{{{}}}{}".format(xml_ns, xml_name) + serialized.set(xml_name, new_attr) # type: ignore + continue + if xml_desc.get("text", False): + serialized.text = new_attr # type: ignore + continue + if isinstance(new_attr, list): + serialized.extend(new_attr) # type: ignore + elif isinstance(new_attr, ET.Element): + # If the down XML has no XML/Name, + # we MUST replace the tag with the local tag. But keeping the namespaces. 
+ if "name" not in getattr(orig_attr, "_xml_map", {}): + splitted_tag = new_attr.tag.split("}") + if len(splitted_tag) == 2: # Namespace + new_attr.tag = "}".join([splitted_tag[0], xml_name]) + else: + new_attr.tag = xml_name + serialized.append(new_attr) # type: ignore + else: # That's a basic type + # Integrate namespace if necessary + local_node = _create_xml_node(xml_name, xml_prefix, xml_ns) + local_node.text = str(new_attr) + serialized.append(local_node) # type: ignore + else: # JSON + for k in reversed(keys): # type: ignore + new_attr = {k: new_attr} + + _new_attr = new_attr + _serialized = serialized + for k in keys: # type: ignore + if k not in _serialized: + _serialized.update(_new_attr) # type: ignore + _new_attr = _new_attr[k] # type: ignore + _serialized = _serialized[k] + except ValueError as err: + if isinstance(err, SerializationError): + raise + + except (AttributeError, KeyError, TypeError) as err: + msg = "Attribute {} in object {} cannot be serialized.\n{}".format(attr_name, class_name, str(target_obj)) + raise SerializationError(msg) from err + return serialized + + def body(self, data, data_type, **kwargs): + """Serialize data intended for a request body. + + :param object data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: dict + :raises SerializationError: if serialization fails. 
+ :raises ValueError: if data is None + :returns: The serialized request body + """ + + # Just in case this is a dict + internal_data_type_str = data_type.strip("[]{}") + internal_data_type = self.dependencies.get(internal_data_type_str, None) + try: + is_xml_model_serialization = kwargs["is_xml"] + except KeyError: + if internal_data_type and issubclass(internal_data_type, Model): + is_xml_model_serialization = kwargs.setdefault("is_xml", internal_data_type.is_xml_model()) + else: + is_xml_model_serialization = False + if internal_data_type and not isinstance(internal_data_type, Enum): + try: + deserializer = Deserializer(self.dependencies) + # Since it's on serialization, it's almost sure that format is not JSON REST + # We're not able to deal with additional properties for now. + deserializer.additional_properties_detection = False + if is_xml_model_serialization: + deserializer.key_extractors = [ # type: ignore + attribute_key_case_insensitive_extractor, + ] + else: + deserializer.key_extractors = [ + rest_key_case_insensitive_extractor, + attribute_key_case_insensitive_extractor, + last_rest_key_case_insensitive_extractor, + ] + data = deserializer._deserialize(data_type, data) # pylint: disable=protected-access + except DeserializationError as err: + raise SerializationError("Unable to build a model: " + str(err)) from err + + return self._serialize(data, data_type, **kwargs) + + def url(self, name, data, data_type, **kwargs): + """Serialize data intended for a URL path. + + :param str name: The name of the URL path parameter. + :param object data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str + :returns: The serialized URL path + :raises TypeError: if serialization fails. 
+ :raises ValueError: if data is None + """ + try: + output = self.serialize_data(data, data_type, **kwargs) + if data_type == "bool": + output = json.dumps(output) + + if kwargs.get("skip_quote") is True: + output = str(output) + output = output.replace("{", quote("{")).replace("}", quote("}")) + else: + output = quote(str(output), safe="") + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return output + + def query(self, name, data, data_type, **kwargs): + """Serialize data intended for a URL query. + + :param str name: The name of the query parameter. + :param object data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str, list + :raises TypeError: if serialization fails. + :raises ValueError: if data is None + :returns: The serialized query parameter + """ + try: + # Treat the list aside, since we don't want to encode the div separator + if data_type.startswith("["): + internal_data_type = data_type[1:-1] + do_quote = not kwargs.get("skip_quote", False) + return self.serialize_iter(data, internal_data_type, do_quote=do_quote, **kwargs) + + # Not a list, regular serialization + output = self.serialize_data(data, data_type, **kwargs) + if data_type == "bool": + output = json.dumps(output) + if kwargs.get("skip_quote") is True: + output = str(output) + else: + output = quote(str(output), safe="") + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return str(output) + + def header(self, name, data, data_type, **kwargs): + """Serialize data intended for a request header. + + :param str name: The name of the header. + :param object data: The data to be serialized. + :param str data_type: The type to be serialized from. + :rtype: str + :raises TypeError: if serialization fails. 
+ :raises ValueError: if data is None + :returns: The serialized header + """ + try: + if data_type in ["[str]"]: + data = ["" if d is None else d for d in data] + + output = self.serialize_data(data, data_type, **kwargs) + if data_type == "bool": + output = json.dumps(output) + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return str(output) + + def serialize_data(self, data, data_type, **kwargs): + """Serialize generic data according to supplied data type. + + :param object data: The data to be serialized. + :param str data_type: The type to be serialized from. + :raises AttributeError: if required data is None. + :raises ValueError: if data is None + :raises SerializationError: if serialization fails. + :returns: The serialized data. + :rtype: str, int, float, bool, dict, list + """ + if data is None: + raise ValueError("No value for given attribute") + + try: + if data is CoreNull: + return None + if data_type in self.basic_types.values(): + return self.serialize_basic(data, data_type, **kwargs) + + if data_type in self.serialize_type: + return self.serialize_type[data_type](data, **kwargs) + + # If dependencies is empty, try with current data class + # It has to be a subclass of Enum anyway + enum_type = self.dependencies.get(data_type, cast(type, data.__class__)) + if issubclass(enum_type, Enum): + return Serializer.serialize_enum(data, enum_obj=enum_type) + + iter_type = data_type[0] + data_type[-1] + if iter_type in self.serialize_type: + return self.serialize_type[iter_type](data, data_type[1:-1], **kwargs) + + except (ValueError, TypeError) as err: + msg = "Unable to serialize value: {!r} as type: {!r}." 
+ raise SerializationError(msg.format(data, data_type)) from err + return self._serialize(data, **kwargs) + + @classmethod + def _get_custom_serializers(cls, data_type, **kwargs): # pylint: disable=inconsistent-return-statements + custom_serializer = kwargs.get("basic_types_serializers", {}).get(data_type) + if custom_serializer: + return custom_serializer + if kwargs.get("is_xml", False): + return cls._xml_basic_types_serializers.get(data_type) + + @classmethod + def serialize_basic(cls, data, data_type, **kwargs): + """Serialize basic builting data type. + Serializes objects to str, int, float or bool. + + Possible kwargs: + - basic_types_serializers dict[str, callable] : If set, use the callable as serializer + - is_xml bool : If set, use xml_basic_types_serializers + + :param obj data: Object to be serialized. + :param str data_type: Type of object in the iterable. + :rtype: str, int, float, bool + :return: serialized object + :raises TypeError: raise if data_type is not one of str, int, float, bool. + """ + custom_serializer = cls._get_custom_serializers(data_type, **kwargs) + if custom_serializer: + return custom_serializer(data) + if data_type == "str": + return cls.serialize_unicode(data) + if data_type == "int": + return int(data) + if data_type == "float": + return float(data) + if data_type == "bool": + return bool(data) + raise TypeError("Unknown basic data type: {}".format(data_type)) + + @classmethod + def serialize_unicode(cls, data): + """Special handling for serializing unicode strings in Py2. + Encode to UTF-8 if unicode, otherwise handle as a str. + + :param str data: Object to be serialized. 
+ :rtype: str + :return: serialized object + """ + try: # If I received an enum, return its value + return data.value + except AttributeError: + pass + + try: + if isinstance(data, unicode): # type: ignore + # Don't change it, JSON and XML ElementTree are totally able + # to serialize correctly u'' strings + return data + except NameError: + return str(data) + return str(data) + + def serialize_iter(self, data, iter_type, div=None, **kwargs): + """Serialize iterable. + + Supported kwargs: + - serialization_ctxt dict : The current entry of _attribute_map, or same format. + serialization_ctxt['type'] should be same as data_type. + - is_xml bool : If set, serialize as XML + + :param list data: Object to be serialized. + :param str iter_type: Type of object in the iterable. + :param str div: If set, this str will be used to combine the elements + in the iterable into a combined string. Default is 'None'. + Defaults to False. + :rtype: list, str + :return: serialized iterable + """ + if isinstance(data, str): + raise SerializationError("Refuse str type as a valid iter type.") + + serialization_ctxt = kwargs.get("serialization_ctxt", {}) + is_xml = kwargs.get("is_xml", False) + + serialized = [] + for d in data: + try: + serialized.append(self.serialize_data(d, iter_type, **kwargs)) + except ValueError as err: + if isinstance(err, SerializationError): + raise + serialized.append(None) + + if kwargs.get("do_quote", False): + serialized = ["" if s is None else quote(str(s), safe="") for s in serialized] + + if div: + serialized = ["" if s is None else str(s) for s in serialized] + serialized = div.join(serialized) + + if "xml" in serialization_ctxt or is_xml: + # XML serialization is more complicated + xml_desc = serialization_ctxt.get("xml", {}) + xml_name = xml_desc.get("name") + if not xml_name: + xml_name = serialization_ctxt["key"] + + # Create a wrap node if necessary (use the fact that Element and list have "append") + is_wrapped = xml_desc.get("wrapped", False) + 
def serialize_dict(self, attr, dict_type, **kwargs):
    """Serialize every value of a mapping as ``dict_type``.

    Reads the optional ``serialization_ctxt`` keyword argument; when it
    carries an ``"xml"`` entry the result is an XML node instead of a dict.

    :param dict attr: Object to be serialized.
    :param str dict_type: Type of object in the dictionary.
    :rtype: dict
    :return: serialized dictionary
    """
    context = kwargs.get("serialization_ctxt", {})
    converted = {}
    for raw_key, raw_value in attr.items():
        try:
            item = self.serialize_data(raw_value, dict_type, **kwargs)
            item_key = self.serialize_unicode(raw_key)
        except ValueError as err:
            # SerializationError is re-raised; any other ValueError nulls the entry.
            if isinstance(err, SerializationError):
                raise
            item, item_key = None, self.serialize_unicode(raw_key)
        converted[item_key] = item

    if "xml" not in context:
        return converted

    # XML serialization is more complicated: one child element per entry.
    xml_desc = context["xml"]
    root = _create_xml_node(xml_desc["name"], xml_desc.get("prefix", None), xml_desc.get("ns", None))
    for child_tag, child_text in converted.items():
        ET.SubElement(root, child_tag).text = child_text
    return root
+ :rtype: dict or str + :return: serialized object + """ + if attr is None: + return None + if isinstance(attr, ET.Element): + return attr + obj_type = type(attr) + if obj_type in self.basic_types: + return self.serialize_basic(attr, self.basic_types[obj_type], **kwargs) + if obj_type is _long_type: + return self.serialize_long(attr) + if obj_type is str: + return self.serialize_unicode(attr) + if obj_type is datetime.datetime: + return self.serialize_iso(attr) + if obj_type is datetime.date: + return self.serialize_date(attr) + if obj_type is datetime.time: + return self.serialize_time(attr) + if obj_type is datetime.timedelta: + return self.serialize_duration(attr) + if obj_type is decimal.Decimal: + return self.serialize_decimal(attr) + + # If it's a model or I know this dependency, serialize as a Model + if obj_type in self.dependencies.values() or isinstance(attr, Model): + return self._serialize(attr) + + if obj_type == dict: + serialized = {} + for key, value in attr.items(): + try: + serialized[self.serialize_unicode(key)] = self.serialize_object(value, **kwargs) + except ValueError: + serialized[self.serialize_unicode(key)] = None + return serialized + + if obj_type == list: + serialized = [] + for obj in attr: + try: + serialized.append(self.serialize_object(obj, **kwargs)) + except ValueError: + pass + return serialized + return str(attr) + + @staticmethod + def serialize_enum(attr, enum_obj=None): + try: + result = attr.value + except AttributeError: + result = attr + try: + enum_obj(result) # type: ignore + return result + except ValueError as exc: + for enum_value in enum_obj: # type: ignore + if enum_value.value.lower() == str(attr).lower(): + return enum_value.value + error = "{!r} is not valid value for enum {!r}" + raise SerializationError(error.format(attr, enum_obj)) from exc + + @staticmethod + def serialize_bytearray(attr, **kwargs): # pylint: disable=unused-argument + """Serialize bytearray into base-64 string. 
+ + :param str attr: Object to be serialized. + :rtype: str + :return: serialized base64 + """ + return b64encode(attr).decode() + + @staticmethod + def serialize_base64(attr, **kwargs): # pylint: disable=unused-argument + """Serialize str into base-64 string. + + :param str attr: Object to be serialized. + :rtype: str + :return: serialized base64 + """ + encoded = b64encode(attr).decode("ascii") + return encoded.strip("=").replace("+", "-").replace("/", "_") + + @staticmethod + def serialize_decimal(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Decimal object to float. + + :param decimal attr: Object to be serialized. + :rtype: float + :return: serialized decimal + """ + return float(attr) + + @staticmethod + def serialize_long(attr, **kwargs): # pylint: disable=unused-argument + """Serialize long (Py2) or int (Py3). + + :param int attr: Object to be serialized. + :rtype: int/long + :return: serialized long + """ + return _long_type(attr) + + @staticmethod + def serialize_date(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Date object into ISO-8601 formatted string. + + :param Date attr: Object to be serialized. + :rtype: str + :return: serialized date + """ + if isinstance(attr, str): + attr = isodate.parse_date(attr) + t = "{:04}-{:02}-{:02}".format(attr.year, attr.month, attr.day) + return t + + @staticmethod + def serialize_time(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Time object into ISO-8601 formatted string. + + :param datetime.time attr: Object to be serialized. + :rtype: str + :return: serialized time + """ + if isinstance(attr, str): + attr = isodate.parse_time(attr) + t = "{:02}:{:02}:{:02}".format(attr.hour, attr.minute, attr.second) + if attr.microsecond: + t += ".{:02}".format(attr.microsecond) + return t + + @staticmethod + def serialize_duration(attr, **kwargs): # pylint: disable=unused-argument + """Serialize TimeDelta object into ISO-8601 formatted string. 
+ + :param TimeDelta attr: Object to be serialized. + :rtype: str + :return: serialized duration + """ + if isinstance(attr, str): + attr = isodate.parse_duration(attr) + return isodate.duration_isoformat(attr) + + @staticmethod + def serialize_rfc(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Datetime object into RFC-1123 formatted string. + + :param Datetime attr: Object to be serialized. + :rtype: str + :raises TypeError: if format invalid. + :return: serialized rfc + """ + try: + if not attr.tzinfo: + _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") + utc = attr.utctimetuple() + except AttributeError as exc: + raise TypeError("RFC1123 object must be valid Datetime object.") from exc + + return "{}, {:02} {} {:04} {:02}:{:02}:{:02} GMT".format( + Serializer.days[utc.tm_wday], + utc.tm_mday, + Serializer.months[utc.tm_mon], + utc.tm_year, + utc.tm_hour, + utc.tm_min, + utc.tm_sec, + ) + + @staticmethod + def serialize_iso(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Datetime object into ISO-8601 formatted string. + + :param Datetime attr: Object to be serialized. + :rtype: str + :raises SerializationError: if format invalid. + :return: serialized iso + """ + if isinstance(attr, str): + attr = isodate.parse_datetime(attr) + try: + if not attr.tzinfo: + _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") + utc = attr.utctimetuple() + if utc.tm_year > 9999 or utc.tm_year < 1: + raise OverflowError("Hit max or min date") + + microseconds = str(attr.microsecond).rjust(6, "0").rstrip("0").ljust(3, "0") + if microseconds: + microseconds = "." + microseconds + date = "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}".format( + utc.tm_year, utc.tm_mon, utc.tm_mday, utc.tm_hour, utc.tm_min, utc.tm_sec + ) + return date + microseconds + "Z" + except (ValueError, OverflowError) as err: + msg = "Unable to serialize datetime object." 
+ raise SerializationError(msg) from err + except AttributeError as err: + msg = "ISO-8601 object must be valid Datetime object." + raise TypeError(msg) from err + + @staticmethod + def serialize_unix(attr, **kwargs): # pylint: disable=unused-argument + """Serialize Datetime object into IntTime format. + This is represented as seconds. + + :param Datetime attr: Object to be serialized. + :rtype: int + :raises SerializationError: if format invalid + :return: serialied unix + """ + if isinstance(attr, int): + return attr + try: + if not attr.tzinfo: + _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") + return int(calendar.timegm(attr.utctimetuple())) + except AttributeError as exc: + raise TypeError("Unix time object must be valid Datetime object.") from exc + + +def rest_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argument + key = attr_desc["key"] + working_data = data + + while "." in key: + # Need the cast, as for some reasons "split" is typed as list[str | Any] + dict_keys = cast(list[str], _FLATTEN.split(key)) + if len(dict_keys) == 1: + key = _decode_attribute_map_key(dict_keys[0]) + break + working_key = _decode_attribute_map_key(dict_keys[0]) + working_data = working_data.get(working_key, data) + if working_data is None: + # If at any point while following flatten JSON path see None, it means + # that all properties under are None as well + return None + key = ".".join(dict_keys[1:]) + + return working_data.get(key) + + +def rest_key_case_insensitive_extractor( # pylint: disable=unused-argument, inconsistent-return-statements + attr, attr_desc, data +): + key = attr_desc["key"] + working_data = data + + while "." 
def last_rest_key_extractor(attr, attr_desc, data):  # pylint: disable=unused-argument
    """Look up the final segment of the attribute's JSON path key in ``data``.

    :param str attr: The attribute to extract
    :param dict attr_desc: The attribute description
    :param dict data: The data to extract from
    :rtype: object
    :returns: The extracted attribute
    """
    # Only the last piece of the split key is used for the lookup.
    path_segments = _FLATTEN.split(attr_desc["key"])
    return attribute_key_extractor(path_segments[-1], None, data)
def xml_key_extractor(attr, attr_desc, data):  # pylint: disable=unused-argument,too-many-return-statements
    """Extract the value described by ``attr_desc`` from an XML element.

    :param str attr: The attribute to extract (unused here).
    :param dict attr_desc: The attribute description; its ``"key"`` and
     ``"type"`` entries are read, plus the optional ``"xml"`` and
     ``"internalType"`` entries.
    :param data: The data to extract from. Anything that is not an
     ``ET.Element`` yields None.
    :rtype: object
    :returns: An XML attribute value, the element text, a list of child
     elements, a single child element, or None when nothing matches.
    :raises DeserializationError: if more matching nodes are found than the
     description allows.
    """
    if isinstance(data, dict):
        return None

    # Test if this model is XML ready first
    if not isinstance(data, ET.Element):
        return None

    xml_desc = attr_desc.get("xml", {})
    xml_name = xml_desc.get("name", attr_desc["key"])

    # Look for a children
    is_iter_type = attr_desc["type"].startswith("[")
    is_wrapped = xml_desc.get("wrapped", False)
    internal_type = attr_desc.get("internalType", None)
    internal_type_xml_map = getattr(internal_type, "_xml_map", {})

    # Integrate namespace if necessary
    # (the attribute's own "ns" wins over the internal type's "ns")
    xml_ns = xml_desc.get("ns", internal_type_xml_map.get("ns", None))
    if xml_ns:
        xml_name = "{{{}}}{}".format(xml_ns, xml_name)

    # If it's an attribute, that's simple
    if xml_desc.get("attr", False):
        return data.get(xml_name)

    # If it's x-ms-text, that's simple too
    if xml_desc.get("text", False):
        return data.text

    # Scenario where I take the local name:
    # - Wrapped node
    # - Internal type is an enum (considered basic types)
    # - Internal type has no XML/Name node
    if is_wrapped or (internal_type and (issubclass(internal_type, Enum) or "name" not in internal_type_xml_map)):
        children = data.findall(xml_name)
    # If internal type has a local name and it's not a list, I use that name
    elif not is_iter_type and internal_type and "name" in internal_type_xml_map:
        xml_name = _extract_name_from_internal_type(internal_type)
        children = data.findall(xml_name)
    # That's an array
    else:
        if internal_type:  # Complex type, ignore itemsName and use the complex type name
            items_name = _extract_name_from_internal_type(internal_type)
        else:
            items_name = xml_desc.get("itemsName", xml_name)
        children = data.findall(items_name)

    # Nothing matched: the result depends on whether a list was expected.
    if len(children) == 0:
        if is_iter_type:
            if is_wrapped:
                return None  # is_wrapped no node, we want None
            return []  # not wrapped, assume empty list
        return None  # Assume it's not there, maybe an optional node.

    # If is_iter_type and not wrapped, return all found children
    if is_iter_type:
        if not is_wrapped:
            return children
        # Iter and wrapped, should have found one node only (the wrap one)
        if len(children) != 1:
            raise DeserializationError(
                "Tried to deserialize an array not wrapped, and found several nodes '{}'. Maybe you should declare this array as wrapped?".format(
                    xml_name
                )
            )
        # The items are the children of the single wrapper node.
        return list(children[0])  # Might be empty list and that's ok.

    # Here it's not a itertype, we should have found one element only or empty
    if len(children) > 1:
        raise DeserializationError("Find several XML '{}' where it was not expected".format(xml_name))
    return children[0]
def __init__(self, classes: Optional[Mapping[str, type]] = None) -> None:
    # Handlers that need no instance state.
    type_handlers = {
        "iso-8601": Deserializer.deserialize_iso,
        "rfc-1123": Deserializer.deserialize_rfc,
        "unix-time": Deserializer.deserialize_unix,
        "duration": Deserializer.deserialize_duration,
        "date": Deserializer.deserialize_date,
        "time": Deserializer.deserialize_time,
        "decimal": Deserializer.deserialize_decimal,
        "long": Deserializer.deserialize_long,
        "bytearray": Deserializer.deserialize_bytearray,
        "base64": Deserializer.deserialize_base64,
    }
    # Handlers bound to this instance.
    type_handlers.update(
        {
            "object": self.deserialize_object,
            "[]": self.deserialize_iter,
            "{}": self.deserialize_dict,
        }
    )
    self.deserialize_type = type_handlers
    # Values already of these types are returned as-is by deserialize_data.
    self.deserialize_expected_types = {
        "duration": (isodate.Duration, datetime.timedelta),
        "iso-8601": datetime.datetime,
    }
    self.dependencies: dict[str, type] = {} if not classes else dict(classes)
    self.key_extractors = [rest_key_extractor, xml_key_extractor]
    # Additional properties only work when "rest_key_extractor" is used to
    # extract the keys. Supporting every key extractor would be too
    # complicated, with no real scenario for now, so this flag allows
    # disabling additional properties detection. Turn it off when the data
    # is NOT expected to come from a JSON REST syntax; otherwise results
    # are unexpected.
    self.additional_properties_detection = True
def _deserialize(self, target_obj, data):  # pylint: disable=inconsistent-return-statements
    """Call the deserializer on a model.

    Data needs to be already deserialized as JSON or XML ElementTree

    If ``data`` already carries an ``_attribute_map`` it is treated as a
    model and updated in place. Otherwise the target is classified, each
    attribute is pulled out with ``self.key_extractors`` and a new model
    is instantiated.

    :param str target_obj: Target data type to deserialize to.
    :param object data: Object to deserialize.
    :raises DeserializationError: if deserialization fails.
    :return: Deserialized object.
    :rtype: object
    """
    # This is already a model, go recursive just in case
    if hasattr(data, "_attribute_map"):
        # Attributes flagged "constant" in _validation are left untouched.
        constants = [name for name, config in getattr(data, "_validation", {}).items() if config.get("constant")]
        try:
            for attr, mapconfig in data._attribute_map.items():  # pylint: disable=protected-access
                if attr in constants:
                    continue
                value = getattr(data, attr)
                if value is None:
                    continue
                local_type = mapconfig["type"]
                # Drop list/dict markers to get the inner type name.
                internal_data_type = local_type.strip("[]{}")
                # NOTE(review): internal_data_type is a str here, so the
                # isinstance(..., Enum) test can never be true - confirm intent.
                if internal_data_type not in self.dependencies or isinstance(internal_data_type, Enum):
                    continue
                setattr(data, attr, self._deserialize(local_type, value))
            return data
        except AttributeError:
            # Implicitly returns None when an attribute lookup fails.
            return

    response, class_name = self._classify_target(target_obj, data)

    # A str response is a type name handled by deserialize_data.
    if isinstance(response, str):
        return self.deserialize_data(data, response)
    if isinstance(response, type) and issubclass(response, Enum):
        return self.deserialize_enum(data, response)

    if data is None or data is CoreNull:
        return data
    try:
        attributes = response._attribute_map  # type: ignore # pylint: disable=protected-access
        d_attrs = {}
        for attr, attr_desc in attributes.items():
            # Check empty string. If it's not empty, someone has a real "additionalProperties"...
            if attr == "additional_properties" and attr_desc["key"] == "":
                continue
            raw_value = None
            # Enhance attr_desc with some dynamic data
            attr_desc = attr_desc.copy()  # Do a copy, do not change the real one
            internal_data_type = attr_desc["type"].strip("[]{}")
            if internal_data_type in self.dependencies:
                attr_desc["internalType"] = self.dependencies[internal_data_type]

            # The first extractor that finds a value wins; later extractors
            # returning a different value are logged and ignored.
            for key_extractor in self.key_extractors:
                found_value = key_extractor(attr, attr_desc, data)
                if found_value is not None:
                    if raw_value is not None and raw_value != found_value:
                        msg = (
                            "Ignoring extracted value '%s' from %s for key '%s'"
                            " (duplicate extraction, follow extractors order)"
                        )
                        _LOGGER.warning(msg, found_value, key_extractor, attr)
                        continue
                    raw_value = found_value

            value = self.deserialize_data(raw_value, attr_desc["type"])
            d_attrs[attr] = value
    except (AttributeError, TypeError, KeyError) as err:
        msg = "Unable to deserialize to object: " + class_name  # type: ignore
        raise DeserializationError(msg) from err
    additional_properties = self._build_additional_properties(attributes, data)
    return self._instantiate_model(response, d_attrs, additional_properties)
If it's not empty, someone has a real "additionalProperties" + return None + if isinstance(data, ET.Element): + data = {el.tag: el.text for el in data} + + known_keys = { + _decode_attribute_map_key(_FLATTEN.split(desc["key"])[0]) + for desc in attribute_map.values() + if desc["key"] != "" + } + present_keys = set(data.keys()) + missing_keys = present_keys - known_keys + return {key: data[key] for key in missing_keys} + + def _classify_target(self, target, data): + """Check to see whether the deserialization target object can + be classified into a subclass. + Once classification has been determined, initialize object. + + :param str target: The target object type to deserialize to. + :param str/dict data: The response data to deserialize. + :return: The classified target object and its class name. + :rtype: tuple + """ + if target is None: + return None, None + + if isinstance(target, str): + try: + target = self.dependencies[target] + except KeyError: + return target, target + + try: + target = target._classify(data, self.dependencies) # type: ignore # pylint: disable=protected-access + except AttributeError: + pass # Target is not a Model, no classify + return target, target.__class__.__name__ # type: ignore + + def failsafe_deserialize(self, target_obj, data, content_type=None): + """Ignores any errors encountered in deserialization, + and falls back to not deserializing the object. Recommended + for use in error deserialization, as we want to return the + HttpResponseError to users, and not have them deal with + a deserialization error. + + :param str target_obj: The target object type to deserialize to. + :param str/dict data: The response data to deserialize. + :param str content_type: Swagger "produces" if available. + :return: Deserialized object. + :rtype: object + """ + try: + return self(target_obj, data, content_type=content_type) + except: # pylint: disable=bare-except + _LOGGER.debug( + "Ran into a deserialization error. 
Ignoring since this is failsafe deserialization", exc_info=True + ) + return None + + @staticmethod + def _unpack_content(raw_data, content_type=None): + """Extract the correct structure for deserialization. + + If raw_data is a PipelineResponse, try to extract the result of RawDeserializer. + if we can't, raise. Your Pipeline should have a RawDeserializer. + + If not a pipeline response and raw_data is bytes or string, use content-type + to decode it. If no content-type, try JSON. + + If raw_data is something else, bypass all logic and return it directly. + + :param obj raw_data: Data to be processed. + :param str content_type: How to parse if raw_data is a string/bytes. + :raises JSONDecodeError: If JSON is requested and parsing is impossible. + :raises UnicodeDecodeError: If bytes is not UTF8 + :rtype: object + :return: Unpacked content. + """ + # Assume this is enough to detect a Pipeline Response without importing it + context = getattr(raw_data, "context", {}) + if context: + if RawDeserializer.CONTEXT_NAME in context: + return context[RawDeserializer.CONTEXT_NAME] + raise ValueError("This pipeline didn't have the RawDeserializer policy; can't deserialize") + + # Assume this is enough to recognize universal_http.ClientResponse without importing it + if hasattr(raw_data, "body"): + return RawDeserializer.deserialize_from_http_generics(raw_data.text(), raw_data.headers) + + # Assume this enough to recognize requests.Response without importing it. + if hasattr(raw_data, "_content_consumed"): + return RawDeserializer.deserialize_from_http_generics(raw_data.text, raw_data.headers) + + if isinstance(raw_data, (str, bytes)) or hasattr(raw_data, "read"): + return RawDeserializer.deserialize_from_text(raw_data, content_type) # type: ignore + return raw_data + + def _instantiate_model(self, response, attrs, additional_properties=None): + """Instantiate a response model passing in deserialized args. + + :param Response response: The response model class. 
def deserialize_data(self, data, data_type):  # pylint: disable=too-many-return-statements
    """Dispatch ``data`` to the deserializer matching ``data_type``.

    :param str data: The response string to be deserialized.
    :param str data_type: The type to deserialize to.
    :raises DeserializationError: if deserialization fails.
    :return: Deserialized object.
    :rtype: object
    """
    if data is None:
        return data

    # Type names that are not parsed from an element's text.
    structural_types = ("object", "[]", r"{}")
    try:
        if not data_type:
            return data
        if data_type in self.basic_types.values():
            return self.deserialize_basic(data, data_type)

        if data_type in self.deserialize_type:
            already_typed = self.deserialize_expected_types.get(data_type, tuple())
            if isinstance(data, already_typed):
                return data
            # An XML node without text has nothing to parse for text-based types.
            if isinstance(data, ET.Element) and data_type not in structural_types and not data.text:
                return None
            handler = self.deserialize_type[data_type]
            return handler(data)

        # "[Foo]" -> "[]" and "{Foo}" -> "{}": container types wrap an inner type.
        container_key = data_type[0] + data_type[-1]
        if container_key in self.deserialize_type:
            return self.deserialize_type[container_key](data, data_type[1:-1])

        obj_type = self.dependencies[data_type]
        if issubclass(obj_type, Enum):
            if isinstance(data, ET.Element):
                data = data.text
            return self.deserialize_enum(data, obj_type)

    except (ValueError, TypeError, AttributeError) as err:
        msg = "Unable to deserialize response data." + " Data: {}, {}".format(data, data_type)
        raise DeserializationError(msg) from err
    return self._deserialize(obj_type, data)
def deserialize_object(self, attr, **kwargs):  # pylint: disable=too-many-return-statements
    """Deserialize a value of unknown type by inspecting its runtime type.

    Dicts and lists are walked recursively; XML elements are returned as-is.

    :param dict attr: Dictionary to be deserialized.
    :return: Deserialized object.
    :rtype: dict
    :raises TypeError: if non-builtin datatype encountered.
    """
    # None stays None; XML is not recursed into, the tree is returned as-is.
    if attr is None or isinstance(attr, ET.Element):
        return attr
    if isinstance(attr, str):
        return self.deserialize_basic(attr, "str")

    obj_type = type(attr)
    if obj_type in self.basic_types:
        return self.deserialize_basic(attr, self.basic_types[obj_type])
    if obj_type is _long_type:
        return self.deserialize_long(attr)

    if obj_type == dict:
        mapping = {}
        for name, item in attr.items():
            try:
                converted = self.deserialize_object(item, **kwargs)
            except ValueError:
                converted = None  # a failing entry is kept with a None value
            mapping[name] = converted
        return mapping

    if obj_type == list:
        sequence = []
        for item in attr:
            try:
                converted = self.deserialize_object(item, **kwargs)
            except ValueError:
                continue  # a failing element is dropped from the result
            sequence.append(converted)
        return sequence

    raise TypeError("Cannot deserialize generic object with type: " + str(obj_type))
+ + :param str attr: response string to be deserialized. + :param str data_type: deserialization data type. + :return: Deserialized basic type. + :rtype: str, int, float or bool + :raises TypeError: if string format is not valid or data_type is not one of str, int, float, bool. + """ + # If we're here, data is supposed to be a basic type. + # If it's still an XML node, take the text + if isinstance(attr, ET.Element): + attr = attr.text + if not attr: + if data_type == "str": + # None or '', node is empty string. + return "" + # None or '', node with a strong type is None. + # Don't try to model "empty bool" or "empty int" + return None + + if data_type == "bool": + if attr in [True, False, 1, 0]: + return bool(attr) + if isinstance(attr, str): + if attr.lower() in ["true", "1"]: + return True + if attr.lower() in ["false", "0"]: + return False + raise TypeError("Invalid boolean value: {}".format(attr)) + + if data_type == "str": + return self.deserialize_unicode(attr) + if data_type == "int": + return int(attr) + if data_type == "float": + return float(attr) + raise TypeError("Unknown basic data type: {}".format(data_type)) + + @staticmethod + def deserialize_unicode(data): + """Preserve unicode objects in Python 2, otherwise return data + as a string. + + :param str data: response string to be deserialized. + :return: Deserialized string. + :rtype: str or unicode + """ + # We might be here because we have an enum modeled as string, + # and we try to deserialize a partial dict with enum inside + if isinstance(data, Enum): + return data + + # Consider this is real string + try: + if isinstance(data, unicode): # type: ignore + return data + except NameError: + return str(data) + return str(data) + + @staticmethod + def deserialize_enum(data, enum_obj): + """Deserialize string into enum object. + + If the string is not a valid enum value it will be returned as-is + and a warning will be logged. + + :param str data: Response string to be deserialized. 
If this value is + None or invalid it will be returned as-is. + :param Enum enum_obj: Enum object to deserialize to. + :return: Deserialized enum object. + :rtype: Enum + """ + if isinstance(data, enum_obj) or data is None: + return data + if isinstance(data, Enum): + data = data.value + if isinstance(data, int): + # Workaround. We might consider remove it in the future. + try: + return list(enum_obj.__members__.values())[data] + except IndexError as exc: + error = "{!r} is not a valid index for enum {!r}" + raise DeserializationError(error.format(data, enum_obj)) from exc + try: + return enum_obj(str(data)) + except ValueError: + for enum_value in enum_obj: + if enum_value.value.lower() == str(data).lower(): + return enum_value + # We don't fail anymore for unknown value, we deserialize as a string + _LOGGER.warning("Deserializer is not able to find %s as valid enum in %s", data, enum_obj) + return Deserializer.deserialize_unicode(data) + + @staticmethod + def deserialize_bytearray(attr): + """Deserialize string into bytearray. + + :param str attr: response string to be deserialized. + :return: Deserialized bytearray + :rtype: bytearray + :raises TypeError: if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + return bytearray(b64decode(attr)) # type: ignore + + @staticmethod + def deserialize_base64(attr): + """Deserialize base64 encoded string into string. + + :param str attr: response string to be deserialized. + :return: Deserialized base64 string + :rtype: bytearray + :raises TypeError: if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore + attr = attr + padding # type: ignore + encoded = attr.replace("-", "+").replace("_", "/") + return b64decode(encoded) + + @staticmethod + def deserialize_decimal(attr): + """Deserialize string into Decimal object. + + :param str attr: response string to be deserialized. 
@staticmethod
def deserialize_duration(attr):
    """Parse an ISO-8601 duration string into a TimeDelta object.

    :param str attr: response string to be deserialized (or an XML element
     whose text holds that string).
    :return: Deserialized duration
    :rtype: TimeDelta
    :raises DeserializationError: if string format invalid.
    """
    raw = attr.text if isinstance(attr, ET.Element) else attr
    try:
        return isodate.parse_duration(raw)
    except (ValueError, OverflowError, AttributeError) as err:
        # Surface every parser failure as the SDK's deserialization error,
        # keeping the original exception as the cause.
        raise DeserializationError("Cannot deserialize duration object.") from err
@staticmethod
def deserialize_rfc(attr):
    """Deserialize RFC-1123 formatted string into Datetime object.

    :param str attr: response string to be deserialized (or an XML element
     whose text holds that string).
    :return: Deserialized RFC datetime, timezone-aware with the offset parsed
     from the string (offset 0 when the string carries none).
    :rtype: Datetime
    :raises DeserializationError: if string format invalid.
    """
    if isinstance(attr, ET.Element):
        attr = attr.text
    try:
        parsed_date = email.utils.parsedate_tz(attr)  # type: ignore
        # parsedate_tz signals an unparseable string by returning None rather
        # than raising; turn that into the error path below instead of letting
        # the subscript underneath fail with a bare TypeError.
        if parsed_date is None:
            raise ValueError("Invalid RFC-1123 datetime string: {!r}".format(attr))
        # Element 9 is the UTC offset in seconds (None when absent).
        # tzinfo is always supplied here, so the result is always aware and no
        # further timezone fix-up is required.
        date_obj = datetime.datetime(
            *parsed_date[:6], tzinfo=datetime.timezone(datetime.timedelta(minutes=(parsed_date[9] or 0) / 60))
        )
    except (ValueError, TypeError) as err:
        msg = "Cannot deserialize to rfc datetime object."
        raise DeserializationError(msg) from err
    return date_obj
@staticmethod
def deserialize_unix(attr):
    """Deserialize a Unix timestamp into a Datetime object.

    The timestamp is interpreted as whole seconds since the epoch and the
    result is created with ``TZ_UTC`` as its timezone.

    :param int attr: Timestamp in seconds to be deserialized (or an XML
     element whose text holds that value).
    :return: Deserialized datetime
    :rtype: Datetime
    :raises DeserializationError: if format invalid
    """
    if isinstance(attr, ET.Element):
        attr = attr.text
    try:
        # Conversion happens inside the try so that malformed or missing XML
        # text is reported the same way as any other bad input.
        attr = int(attr)  # type: ignore
        date_obj = datetime.datetime.fromtimestamp(attr, TZ_UTC)
    except (ValueError, TypeError, OverflowError, OSError) as err:
        # fromtimestamp raises OverflowError/OSError (not only ValueError) for
        # timestamps outside the platform's supported range.
        msg = "Cannot deserialize to unix datetime object."
        raise DeserializationError(msg) from err
    return date_obj
def api_version_validation(**kwargs):
    """Build a decorator that rejects calls unsupported by the client's API version.

    :keyword dict params_added_on: Maps an API version to the names of the
     keyword parameters introduced in that version. Defaults to ``{}``.
    :keyword str method_added_on: API version in which the decorated method
     was introduced. Defaults to ``""``.
    :keyword list api_versions_list: Known API versions; a version's position
     in this list is what gets compared. Defaults to ``[]``.
    :return: A decorator for client methods.
    :rtype: callable
    """
    added_params_by_version = kwargs.pop("params_added_on", {})
    method_since = kwargs.pop("method_added_on", "")
    known_versions = kwargs.pop("api_versions_list", [])

    def _version_rank(value: str, default: int = -1) -> int:
        """Return the position of ``value`` in the known versions list.

        :param value: The API version to look up.
        :type value: str
        :param default: The value to return when the version is not listed.
        :type default: int
        :return: The index of the version, or ``default`` if it is not found.
        :rtype: int
        """
        return known_versions.index(value) if value in known_versions else default

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The first positional argument is expected to be the client, which
            # exposes its version as ``_config.api_version``. Without that
            # attribute chain no validation is possible, so just call through.
            try:
                client_api_version = args[0]._config.api_version  # pylint: disable=protected-access
            except AttributeError:
                return func(*args, **kwargs)

            client_rank = _version_rank(client_api_version)

            if _version_rank(method_since) > client_rank:
                raise ValueError(
                    f"'{func.__name__}' is not available in API version "
                    f"{client_api_version}. Pass service API version {method_since} or newer to your client."
                )

            # Collect every keyword argument actually passed whose introducing
            # version ranks above the client's version.
            too_new = {}
            for version, names in added_params_by_version.items():
                for name in names:
                    if name in kwargs and _version_rank(version) > client_rank:
                        too_new[name] = version

            if too_new:
                message_lines = [
                    f"'{param}' is not available in API version {client_api_version}. "
                    f"Use service API version {version} or newer.\n"
                    for param, version in too_new.items()
                ]
                raise ValueError("".join(message_lines))

            return func(*args, **kwargs)

        return wrapper

    return decorator
+# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._client import FineTuningSessionClient # type: ignore + +try: + from ._patch import __all__ as _patch_all + from ._patch import * +except ImportError: + _patch_all = [] +from ._patch import patch_sdk as _patch_sdk + +__all__ = [ + "FineTuningSessionClient", +] +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore + +_patch_sdk() diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/aio/_client.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/aio/_client.py new file mode 100644 index 000000000000..6b6e271d5e1b --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/aio/_client.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
class FineTuningSessionClient:  # pylint: disable=client-accepts-api-version-keyword
    """FineTuningSessionClient.

    Asynchronous client that wires an ``AsyncPipelineClient`` to the operation
    groups listed below. It can be used as an async context manager.

    :ivar sessions: SessionsOperations operations
    :vartype sessions: azure.ai.finetuning_sessions.aio.operations.SessionsOperations
    :ivar training: TrainingOperations operations
    :vartype training: azure.ai.finetuning_sessions.aio.operations.TrainingOperations
    :ivar checkpoints: CheckpointsOperations operations
    :vartype checkpoints: azure.ai.finetuning_sessions.aio.operations.CheckpointsOperations
    :ivar sampling: SamplingOperations operations
    :vartype sampling: azure.ai.finetuning_sessions.aio.operations.SamplingOperations
    :ivar operations: Operations operations
    :vartype operations: azure.ai.finetuning_sessions.aio.operations.Operations
    :param endpoint: Foundry Project endpoint in the form
     "https://{ai-services-account-name}.services.ai.azure.com/api/projects/{project-name}". If you
     only have one Project in your Foundry Hub, or to target the default Project in your Hub, use
     the form "https://{ai-services-account-name}.services.ai.azure.com/api/projects/_project".
     Required.
    :type endpoint: str
    :param credential: Credential used to authenticate requests to the service. Required.
    :type credential: ~azure.core.credentials_async.AsyncTokenCredential
    :keyword list policies: Optional complete list of pipeline policies. When given, it replaces
     the default policy list assembled in ``__init__``.
    :keyword int polling_interval: Default waiting time between two polls for LRO operations if no
     Retry-After header is present.
    """

    def __init__(self, endpoint: str, credential: "AsyncTokenCredential", **kwargs: Any) -> None:
        # The base URL is a template; the "{endpoint}" placeholder is filled in
        # per request (see send_request, which formats it from the config).
        _endpoint = "{endpoint}"
        self._config = FineTuningSessionClientConfiguration(endpoint=endpoint, credential=credential, **kwargs)

        # A caller-supplied "policies" list is used verbatim; otherwise the
        # default pipeline is built from the configuration's policies.
        _policies = kwargs.pop("policies", None)
        if _policies is None:
            _policies = [
                policies.RequestIdPolicy(**kwargs),
                self._config.headers_policy,
                self._config.user_agent_policy,
                self._config.proxy_policy,
                policies.ContentDecodePolicy(**kwargs),
                self._config.redirect_policy,
                self._config.retry_policy,
                self._config.authentication_policy,
                self._config.custom_hook_policy,
                self._config.logging_policy,
                policies.DistributedTracingPolicy(**kwargs),
                # Only added when a redirect policy is configured.
                policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None,
                self._config.http_logging_policy,
            ]
        self._client: AsyncPipelineClient = AsyncPipelineClient(base_url=_endpoint, policies=_policies, **kwargs)

        self._serialize = Serializer()
        self._deserialize = Deserializer()
        # Client-side validation is switched off on the serializer.
        self._serialize.client_side_validation = False
        # Every operation group shares the same pipeline client, configuration
        # and (de)serializers.
        self.sessions = SessionsOperations(self._client, self._config, self._serialize, self._deserialize)
        self.training = TrainingOperations(self._client, self._config, self._serialize, self._deserialize)
        self.checkpoints = CheckpointsOperations(self._client, self._config, self._serialize, self._deserialize)
        self.sampling = SamplingOperations(self._client, self._config, self._serialize, self._deserialize)
        self.operations = Operations(self._client, self._config, self._serialize, self._deserialize)

    def send_request(
        self, request: HttpRequest, *, stream: bool = False, **kwargs: Any
    ) -> Awaitable[AsyncHttpResponse]:
        """Runs the network request through the client's chained policies.

        >>> from azure.core.rest import HttpRequest
        >>> request = HttpRequest("GET", "https://www.example.org/")

        >>> response = await client.send_request(request)


        For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request

        :param request: The network request you want to make. Required.
        :type request: ~azure.core.rest.HttpRequest
        :keyword bool stream: Whether the response payload will be streamed. Defaults to False.
        :return: The response of your network call. Does not do error handling on your response.
        :rtype: ~azure.core.rest.AsyncHttpResponse
        """

        # Work on a copy so the caller's request object is not mutated when the
        # URL is rewritten below.
        request_copy = deepcopy(request)
        path_format_arguments = {
            "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True),
        }

        # Substitute the "{endpoint}" placeholder with the configured endpoint.
        request_copy.url = self._client.format_url(request_copy.url, **path_format_arguments)
        return self._client.send_request(request_copy, stream=stream, **kwargs)  # type: ignore

    async def close(self) -> None:
        """Close the underlying pipeline client."""
        await self._client.close()

    async def __aenter__(self) -> Self:
        """Enter the underlying pipeline client's async context and return this client."""
        await self._client.__aenter__()
        return self

    async def __aexit__(self, *exc_details: Any) -> None:
        """Exit the underlying pipeline client's async context."""
        await self._client.__aexit__(*exc_details)
class FineTuningSessionClientConfiguration:  # pylint: disable=too-many-instance-attributes
    """Configuration for FineTuningSessionClient.

    Note that all parameters used to create this instance are saved as instance
    attributes.

    :param endpoint: Foundry Project endpoint in the form
     "https://{ai-services-account-name}.services.ai.azure.com/api/projects/{project-name}". If you
     only have one Project in your Foundry Hub, or to target the default Project in your Hub, use
     the form "https://{ai-services-account-name}.services.ai.azure.com/api/projects/_project".
     Required.
    :type endpoint: str
    :param credential: Credential used to authenticate requests to the service. Required.
     Although annotated as an async token credential, an
     ``azure.core.credentials.AzureKeyCredential`` instance is also handled: it selects
     key-based authentication instead of bearer-token authentication.
    :type credential: ~azure.core.credentials_async.AsyncTokenCredential
    :keyword list credential_scopes: Scopes requested for bearer-token authentication.
     Defaults to ``["https://ai.azure.com/.default"]``.
    :keyword int polling_interval: Stored as ``polling_interval``; defaults to 30.
    """

    def __init__(self, endpoint: str, credential: "AsyncTokenCredential", **kwargs: Any) -> None:
        if endpoint is None:
            raise ValueError("Parameter 'endpoint' must not be None.")
        if credential is None:
            raise ValueError("Parameter 'credential' must not be None.")

        self.endpoint = endpoint
        self.credential = credential
        # Popped so the scopes are not forwarded to the policy constructors below.
        self.credential_scopes = kwargs.pop("credential_scopes", ["https://ai.azure.com/.default"])
        # Default SDK moniker (includes the package version) unless the caller set one.
        kwargs.setdefault("sdk_moniker", "finetuning-sessions/{}".format(VERSION))
        self.polling_interval = kwargs.get("polling_interval", 30)
        self._configure(**kwargs)

    def _configure(self, **kwargs: Any) -> None:
        """Create each pipeline policy, preferring an instance supplied by the caller.

        For every policy, a truthy value passed under the matching keyword is
        used as-is; otherwise a default policy is constructed from ``kwargs``.
        """
        self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs)
        self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs)
        self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs)
        self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs)
        self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs)
        self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs)
        self.redirect_policy = kwargs.get("redirect_policy") or policies.AsyncRedirectPolicy(**kwargs)
        self.retry_policy = kwargs.get("retry_policy") or policies.AsyncRetryPolicy(**kwargs)
        self.authentication_policy = kwargs.get("authentication_policy")
        # Only derive an authentication policy from the credential when the
        # caller did not provide one explicitly.
        if self.credential and not self.authentication_policy:
            if isinstance(self.credential, AzureKeyCredential):
                # API key auth: the key credential policy is configured with the
                # header name "api-key".
                self.authentication_policy = AzureKeyCredentialPolicy(
                    self.credential, name="api-key"
                )
            else:
                # Any other credential is treated as a token credential and
                # used with the configured scopes.
                self.authentication_policy = policies.AsyncBearerTokenCredentialPolicy(
                    self.credential, *self.credential_scopes, **kwargs
                )
+ +Usage:: + + from azure.ai.finetuning_sessions.aio import FineTuningSessionClient + + async with FineTuningSessionClient(endpoint, credential) as client: + session_id = await client.create_session(base_model="Llama-3.1-8B") + fb = await client.forward_backward(session_id, batch, loss_fn="cross_entropy") + opt = await client.optim_step(session_id, AdamParams(learning_rate=1e-4)) + await client.close_session(session_id) +""" +from __future__ import annotations + +import asyncio +import json as _json +import logging as _logging +import random as _random +import time as _time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from azure.core.exceptions import ServiceRequestError as _ServiceRequestError +from azure.core.exceptions import ServiceResponseError as _ServiceResponseError +from azure.core.rest import HttpRequest as _HttpRequest + +from .._exceptions import ( + _classify_http_error, + _classify_poll_failure, +) + +from ..models import ( + AdamParams, + CreateSessionRequest, + Datum, + ForwardBackwardInput, + ForwardBackwardOperationResult, + ForwardBackwardRequest, + ForwardInput, + ForwardRequest, + FromCheckpoint, + LoRAConfig, + LossFn, + LossFnConfig, + ModelInput, + ModelInputChunk, + OperationResult, + OptimStepRequest, + SampleRequest, + SamplingParams, + SaveCheckpointRequest, + SaveSamplerWeightsRequest, + FoundryFeaturesOptInKeys, +) +from .._utils.model_base import SdkJSONEncoder as _SdkJSONEncoder, _deserialize as _deserialize_model +from .._patch import ( + _chunk_data, + _combine_fwd_bwd_results, + _normalize_loom_result, + _base_headers, + _log_http, + _LOOM_SUBPATH_TO_OP_TYPE, + _API_VERSION, + _DEFAULT_OPERATION_TIMEOUT_SEC, + _ErrorBudget, + _RETRIEVE_POLL_MIN, + _RETRIEVE_POLL_MAX, + _maybe_log_poll_progress, + _clear_poll_log_state, +) + +if TYPE_CHECKING: + from ._client import FineTuningSessionClient + +_PREVIEW = FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW +_logger = _logging.getLogger(__name__) + 
def _start_heartbeat(
    self: "FineTuningSessionClient",
    session_id: str,
    interval_sec: float = 30.0,
) -> None:
    """Start an asyncio task that sends a heartbeat every ``interval_sec`` seconds.

    The task is registered in ``self._heartbeat_tasks`` under ``session_id``.
    If a heartbeat task is already registered for that session it is cancelled
    first, so repeated calls never leave an orphaned task running.

    Must be called while an event loop is running, because it uses
    ``asyncio.create_task``.

    :param session_id: Session identifier; a leading ``"model_"`` prefix is
     replaced by ``"session_"`` to build the heartbeat URL.
    :param interval_sec: Seconds to sleep before each heartbeat POST.
    """
    _ensure_async_state(self)
    raw_id = session_id.removeprefix("model_")
    heartbeat_session_id = f"session_{raw_id}"

    # Cancel any heartbeat already running for this session. Overwriting the
    # registry entry without cancelling would leak a task that keeps POSTing
    # and that _stop_heartbeat could no longer reach.
    previous_task = self._heartbeat_tasks.pop(session_id, None)
    if previous_task is not None:
        previous_task.cancel()

    async def _heartbeat_loop() -> None:
        while True:
            await asyncio.sleep(interval_sec)
            try:
                hb_req = _HttpRequest(
                    "POST",
                    "{endpoint}"
                    + f"/fine_tuning/sessions/{heartbeat_session_id}/heartbeat",
                    headers=_base_headers(),
                    params={"api-version": _API_VERSION},
                )
                resp = await self.send_request(hb_req)
                if resp.status_code != 200:
                    _logger.warning(
                        "[heartbeat] status=%d for %s",
                        resp.status_code,
                        heartbeat_session_id,
                    )
            except asyncio.CancelledError:
                # Cancellation is the normal way to stop the loop.
                return
            except Exception as exc:
                # Best effort: a failed heartbeat is logged and the loop goes on.
                _logger.warning(
                    "[heartbeat] failed for %s: %s",
                    heartbeat_session_id,
                    exc,
                )

    task = asyncio.create_task(
        _heartbeat_loop(), name=f"fts-heartbeat-{session_id}"
    )
    self._heartbeat_tasks[session_id] = task
    _logger.info(
        "[heartbeat] started (interval=%.0fs, session=%s)",
        interval_sec,
        heartbeat_session_id,
    )
background heartbeat task for a session.""" + _ensure_async_state(self) + task = self._heartbeat_tasks.pop(session_id, None) + if task is not None: + task.cancel() + + +# -- Low-level helpers --------------------------------------------------------- + +async def _post( + self: "FineTuningSessionClient", + subpath: str, + body_model: Any, + extra_params: Optional[dict] = None, +) -> tuple[str, str]: + """POST to enqueue a job. Returns ``(request_id, op_type)``. + + Gated by ``_post_semaphore`` to prevent connection storms. + Retries on 408/409/429/5xx with exponential backoff (max 5 retries). + """ + _ensure_async_state(self) + body_json = _json.dumps( + body_model, cls=_SdkJSONEncoder, exclude_readonly=True + ) + post_params: dict = {"api-version": _API_VERSION} + if extra_params: + post_params.update(extra_params) + op_type = _LOOM_SUBPATH_TO_OP_TYPE.get(subpath.rsplit("/", 1)[-1], "") + _log_http("request", "POST", subpath, body=_json.loads(body_json)) + + max_retries = 2 + _BASE_TIMEOUT_SEC = 100 # per-request timeout; escalated on each retry + _TIMEOUT_MULTIPLIER = 1.5 + # Non-retryable status codes: deterministic rejections that will + # never succeed on retry. Surface a typed exception immediately. + _NON_RETRYABLE = frozenset({400, 413, 422}) + + async with self._post_semaphore: + last_status: Optional[int] = None + consecutive_same_status = 0 + + for attempt in range(max_retries + 1): + # Escalate per-request timeout on retries so later attempts + # aren't doomed to the same cutoff when the server is slow. 
+ request_timeout = _BASE_TIMEOUT_SEC * (_TIMEOUT_MULTIPLIER ** attempt) + try: + post_req = _HttpRequest( + "POST", + "{endpoint}" + subpath, + headers=_base_headers({"Content-Type": "application/json"}), + params=post_params, + content=body_json, + ) + resp = await self.send_request( + post_req, connection_timeout=request_timeout + ) + sc = resp.status_code + + # --- Non-retryable: classify and raise immediately --- + if sc in _NON_RETRYABLE: + try: + resp_body = resp.json() + except Exception: + resp_body = None + typed = _classify_http_error(sc, resp_body, response=resp) + if typed is not None: + raise typed + resp.raise_for_status() + + # --- Track repeated same-status for pattern detection --- + if sc in (408, 409, 429) or (500 <= sc < 600): + if sc == last_status: + consecutive_same_status += 1 + else: + consecutive_same_status = 1 + last_status = sc + + # After 2 consecutive identical failures, it's likely + # persistent — classify and raise typed if possible. + if consecutive_same_status >= 2: + try: + resp_body = resp.json() + except Exception: + resp_body = None + typed = _classify_http_error(sc, resp_body, response=resp) + if typed is not None: + raise typed + + if attempt < max_retries: + # Honor Retry-After header from server (seconds). + retry_after = resp.headers.get("Retry-After") + if retry_after is not None: + try: + wait = float(retry_after) + except (ValueError, TypeError): + wait = min(0.5 * (2**attempt), 10.0) + else: + wait = min(0.5 * (2**attempt), 10.0) + # Add jitter to prevent thundering herd. + wait *= 1 - 0.25 * _random.random() + _logger.warning( + "POST %s returned %d, retry %d/%d in %.1fs", + subpath, + sc, + attempt + 1, + max_retries, + wait, + ) + await asyncio.sleep(wait) + continue + + # Exhausted retries — classify before raising generic error. 
+ try: + resp_body = resp.json() + except Exception: + resp_body = None + typed = _classify_http_error(sc, resp_body, response=resp) + if typed is not None: + raise typed + + _log_http( + "response", + "POST", + subpath, + status=sc, + body=resp.json() if sc < 400 else None, + ) + resp.raise_for_status() + data = resp.json() + return ( + data["request_id"], + op_type, + ) + + except (_ServiceRequestError, _ServiceResponseError) as exc: + if attempt < max_retries: + # Back off longer for network errors (timeout/connection + # failures). The actual request timeout is escalated via + # connection_timeout above. + wait = min(1.0 * (2**attempt), 10.0) * ( + 1 - 0.25 * _random.random() + ) + _logger.warning( + "POST %s %s(%s), retry %d/%d in %.1fs", + subpath, + type(exc).__name__, + exc, + attempt + 1, + max_retries, + wait, + ) + await asyncio.sleep(wait) + continue + raise + + # Should not be reached. + raise RuntimeError(f"POST {subpath} failed after {max_retries} retries") + + +async def _poll( + self: "FineTuningSessionClient", + session_id: str, + request_id: str, + op_type: str, + extra_result_fields: Optional[dict] = None, + error_budget_sec: Optional[float] = None, +) -> OperationResult: + """Short-poll the envelope endpoint until the request resolves. + + Adaptive backoff: starts at ``_RETRIEVE_POLL_MIN``, doubles up to + ``_RETRIEVE_POLL_MAX``. Retries 5xx and transient network errors. + + ``error_budget_sec`` is an ERROR budget, not a wall-clock budget (matching + the sync ``_post_and_poll``): healthy pending progress is unbounded and + CLEARS the budget, while a sustained streak of 5xx / 408 / 429 / transient + network errors longer than the budget raises ``TimeoutError``. Pass ``None`` + to disable it (retry forever). 
+ """ + poll_path = f"/fine_tuning/sessions/{session_id}/request/{request_id}" + conn_backoff = 1.0 + poll_backoff = _RETRIEVE_POLL_MIN + error_budget = _ErrorBudget.for_polling( + error_budget_sec, op_type=op_type, request_id=request_id + ) + poll_start = _time.monotonic() + + while True: + try: + poll_req = _HttpRequest( + "GET", + "{endpoint}" + + f"/fine_tuning/sessions/{session_id}/request/{request_id}", + headers=_base_headers(), + params={"api-version": _API_VERSION}, + ) + _log_http("request", "GET", poll_path) + resp = await self.send_request(poll_req) + + if resp.status_code == 200: + envelope = resp.json() + _log_http("response", "GET", poll_path, status=200, body=envelope) + conn_backoff = 1.0 + + status = envelope.get("status") + if status == "pending": + elapsed = _time.monotonic() - poll_start + _maybe_log_poll_progress(envelope, session_id, request_id, op_type, elapsed) + error_budget.clear() + await asyncio.sleep(poll_backoff) + poll_backoff = min(poll_backoff * 2, _RETRIEVE_POLL_MAX) + continue + if status == "failed": + _clear_poll_log_state(session_id, request_id, op_type) + typed = _classify_poll_failure(envelope, session_id=session_id) + if typed is not None: + raise typed + raise RuntimeError( + f"{op_type} request {request_id} failed " + f"[{envelope.get('error_code') or envelope.get('code') or 'unknown'}]: " + f"{envelope.get('error')} " + f"(debug_ref={envelope.get('debug_ref') or 'n/a'})" + ) + if status != "completed": + _clear_poll_log_state(session_id, request_id, op_type) + raise RuntimeError( + f"Unexpected envelope status {status!r} for " + f"{op_type} request {request_id}: {envelope}" + ) + + # completed -- normalize and deserialize. 
+ _clear_poll_log_state(session_id, request_id, op_type) + raw = envelope.get("result") or {} + normalized = _normalize_loom_result(raw, op_type, request_id) + if extra_result_fields: + for k, v in extra_result_fields.items(): + if not normalized.get(k): + normalized[k] = v + return _deserialize_model(OperationResult, normalized) + + # Retryable HTTP status. + if ( + 500 <= resp.status_code < 600 + or resp.status_code in (408, 429) + ): + body: Optional[Any] = None + try: + body = resp.json() + except Exception: + body = None + _log_http( + "response", "GET", poll_path, + status=resp.status_code, body=body, + ) + elapsed = _time.monotonic() - poll_start + _logger.debug( + "[poller] retry on %s/%s after HTTP %d (%.0fs elapsed)", + session_id, request_id, resp.status_code, elapsed, + ) + error_budget.consume(f"HTTP {resp.status_code}") + conn_backoff = 1.0 + # Honor Retry-After header if present. + retry_after = resp.headers.get("Retry-After") + if retry_after is not None: + try: + poll_wait = float(retry_after) + except (ValueError, TypeError): + poll_wait = _RETRIEVE_POLL_MIN + else: + poll_wait = _RETRIEVE_POLL_MIN + await asyncio.sleep(poll_wait) + continue + + # Non-retryable HTTP error. 
async def create_session(
    self: "FineTuningSessionClient",
    *,
    base_model: str,
    lora_config: Optional[LoRAConfig] = None,
    type: str = "training",
    from_checkpoint: Optional[FromCheckpoint] = None,
    timeout_sec: float = 600.0,
) -> str:
    """Create a fine-tuning session and wait until the model is loaded.

    POSTs the create request, then polls the returned request until its
    envelope reports ``"completed"``. On success the background heartbeat is
    started for the new session.

    :param base_model: Name of the base model to load.
    :param lora_config: Optional LoRA adapter config.
    :param type: Session type string. Defaults to ``"training"``.
    :param from_checkpoint: Optional checkpoint to resume from.
    :param timeout_sec: Maximum seconds to wait for model load.
    :return: The ``session_id`` string (e.g. ``"model_abc123"``), i.e. the
        server's raw id with a ``model_`` prefix.
    :raises RuntimeError: If the load request reports ``"failed"`` and no typed
        error is available, or if ``timeout_sec`` elapses first.
    :raises Exception: Any typed error returned by ``_classify_http_error`` or
        ``_classify_poll_failure``; otherwise whatever ``raise_for_status()``
        raises for an HTTP error status.
    """
    _ensure_async_state(self)

    body = _json.loads(
        _json.dumps(
            CreateSessionRequest(
                type=type, base_model=base_model, lora_config=lora_config,
            ),
            cls=_SdkJSONEncoder,
            exclude_readonly=True,
        )
    )
    if from_checkpoint is not None:
        body["from_checkpoint"] = _json.loads(
            _json.dumps(from_checkpoint, cls=_SdkJSONEncoder, exclude_readonly=True)
        )

    body_json = _json.dumps(body)
    post_req = _HttpRequest(
        "POST",
        "{endpoint}/fine_tuning/sessions",
        headers=_base_headers({"Content-Type": "application/json"}),
        params={"api-version": _API_VERSION},
        content=body_json,
    )
    _log_http("request", "POST", "/fine_tuning/sessions", body=_json.loads(body_json))
    post_resp = await self.send_request(post_req)

    if post_resp.status_code >= 400:
        _log_http(
            "response",
            "POST",
            "/fine_tuning/sessions",
            status=post_resp.status_code,
            body=None,
        )
        try:
            resp_body = post_resp.json()
        except Exception:
            # Error bodies are best-effort: a non-JSON body must not mask the HTTP error.
            resp_body = None
        typed = _classify_http_error(post_resp.status_code, resp_body, response=post_resp)
        if typed is not None:
            raise typed
        post_resp.raise_for_status()

    # Parse the success body once and reuse it for both logging and field access.
    data = post_resp.json()
    _log_http(
        "response",
        "POST",
        "/fine_tuning/sessions",
        status=post_resp.status_code,
        body=data,
    )
    raw_session_id: str = data["session_id"]
    request_id: str = data["request_id"]
    _logger.info(
        "[create_session] POST response: raw_session_id=%s, request_id=%s",
        raw_session_id,
        request_id,
    )

    session_id: str = f"model_{raw_session_id}"
    _logger.info(
        "[create_session] session_id transformed: raw=%s -> resource_id=%s",
        raw_session_id,
        session_id,
    )

    # Poll until model load completes.
    poll_path = f"/fine_tuning/sessions/{session_id}/request/{request_id}"
    timeout_message = (
        f"Timed out after {timeout_sec}s waiting for "
        f"session_id={raw_session_id} to become ready"
    )
    deadline = _time.monotonic() + timeout_sec
    conn_backoff = 1.0
    poll_backoff = _RETRIEVE_POLL_MIN
    _create_poll_start = _time.monotonic()
    while True:
        try:
            poll_req = _HttpRequest(
                "GET",
                "{endpoint}" + poll_path,
                headers=_base_headers(),
                params={"api-version": _API_VERSION},
            )
            _log_http("request", "GET", poll_path)
            poll_resp = await self.send_request(poll_req)
            status_code = poll_resp.status_code
            envelope = poll_resp.json() if status_code == 200 else None
            _log_http(
                "response", "GET", poll_path,
                status=status_code, body=envelope,
            )

            if status_code == 200:
                env_status = envelope.get("status")
                if env_status == "completed":
                    _logger.info("[create_session] model load completed: %s", envelope)
                    _clear_poll_log_state(session_id, request_id, "create_session")
                    break
                if env_status == "failed":
                    _clear_poll_log_state(session_id, request_id, "create_session")
                    typed = _classify_poll_failure(envelope, session_id=session_id)
                    if typed is not None:
                        raise typed
                    raise RuntimeError(
                        f"Model load failed for session_id={raw_session_id} "
                        f"[{envelope.get('error_code') or envelope.get('code') or 'unknown'}]: "
                        f"{envelope.get('error') or 'unknown error'} "
                        f"(debug_ref={envelope.get('debug_ref') or 'n/a'})"
                    )
                # Any other status is treated as still pending -> adaptive backoff.
                if _time.monotonic() > deadline:
                    raise RuntimeError(timeout_message)
                elapsed = _time.monotonic() - _create_poll_start
                _maybe_log_poll_progress(envelope, session_id, request_id, "create_session", elapsed)
                conn_backoff = 1.0
                await asyncio.sleep(poll_backoff)
                poll_backoff = min(poll_backoff * 2, _RETRIEVE_POLL_MAX)
                continue

            # Retryable HTTP status codes.
            if 500 <= status_code < 600 or status_code in (408, 429):
                if _time.monotonic() > deadline:
                    raise RuntimeError(timeout_message)
                elapsed = _time.monotonic() - _create_poll_start
                _logger.debug(
                    "[poller] retry on %s/%s after HTTP %d (%.0fs elapsed)",
                    session_id, request_id, status_code, elapsed,
                )
                conn_backoff = 1.0
                # Honor Retry-After when it is a number of seconds (same rule as
                # _poll), but never sleep beyond the caller's deadline.
                retry_after = poll_resp.headers.get("Retry-After")
                try:
                    retry_wait = (
                        float(retry_after) if retry_after is not None else _RETRIEVE_POLL_MIN
                    )
                except (ValueError, TypeError):
                    retry_wait = _RETRIEVE_POLL_MIN
                retry_wait = max(0.0, min(retry_wait, deadline - _time.monotonic()))
                await asyncio.sleep(retry_wait)
                continue

            # Non-retryable error.
            try:
                poll_body = poll_resp.json()
            except Exception:
                poll_body = None
            typed = _classify_http_error(
                status_code, poll_body, response=poll_resp, session_id=session_id
            )
            if typed is not None:
                raise typed
            poll_resp.raise_for_status()

            # Reaching this point means the status was neither 200, retryable,
            # nor an error (e.g. 202/204). Without a sleep and a deadline check
            # the loop would spin at full speed forever, so treat it as "not
            # ready yet" and back off like a pending envelope.
            if _time.monotonic() > deadline:
                raise RuntimeError(timeout_message)
            elapsed = _time.monotonic() - _create_poll_start
            _logger.debug(
                "[create_session] unexpected HTTP %d while polling %s/%s (%.0fs elapsed); retrying",
                status_code, session_id, request_id, elapsed,
            )
            conn_backoff = 1.0
            await asyncio.sleep(poll_backoff)
            poll_backoff = min(poll_backoff * 2, _RETRIEVE_POLL_MAX)
            continue

        except (_ServiceRequestError, _ServiceResponseError) as exc:
            # Transport-level failure: retry with exponential backoff until the deadline.
            if _time.monotonic() > deadline:
                raise RuntimeError(timeout_message) from exc
            elapsed = _time.monotonic() - _create_poll_start
            _logger.warning(
                "[poller] retry on %s/%s after %s(%s) (%.0fs elapsed), backoff %.1fs",
                session_id, request_id, type(exc).__name__, exc, elapsed, conn_backoff,
            )
            await asyncio.sleep(conn_backoff)
            conn_backoff = min(conn_backoff * 2, 30.0)
            continue

    _start_heartbeat(self, session_id)
    return session_id
async def forward_backward(
    self: "FineTuningSessionClient",
    session_id: str,
    batch: List[Datum],
    *,
    loss_fn: Union[str, LossFn] = LossFn.CROSS_ENTROPY,
    loss_fn_config: Optional[LossFnConfig] = None,
) -> OperationResult:
    """Run a forward + backward pass over a mini-batch and wait for the result.

    The batch is split with ``_chunk_data``. When that yields at most one
    chunk, the whole batch goes out as a single request. Otherwise every chunk
    is posted and polled concurrently via ``asyncio.gather`` and the per-chunk
    results are merged with ``_combine_fwd_bwd_results``.

    :param session_id: The session ID returned by ``create_session``.
    :param batch: List of Datum.
    :param loss_fn: Loss function name. Defaults to ``"cross_entropy"``.
    :param loss_fn_config: Optional per-loss hyper-parameters.
    :return: OperationResult.
    """
    endpoint_path = f"/fine_tuning/sessions/{session_id}/forward_backward"

    def _request_for(data: List[Datum]) -> ForwardBackwardRequest:
        # Same request shape for the single-request and the per-chunk case.
        return ForwardBackwardRequest(
            forward_backward_input=ForwardBackwardInput(
                data=data,
                loss_fn=loss_fn,
                loss_fn_config=loss_fn_config,
            )
        )

    pieces = _chunk_data(batch)
    piece_count = len(pieces)
    if piece_count <= 1:
        # No real chunking: send the caller's batch object as-is.
        return await _post_and_poll(self, session_id, endpoint_path, _request_for(batch))

    sizes = [len(piece) for piece in pieces]
    _logger.info(
        "[forward_backward] batch of %d datums split into %d chunks: %s",
        len(batch),
        piece_count,
        sizes,
    )

    async def _run_piece(index: int, piece: List[Datum]) -> ForwardBackwardOperationResult:
        _logger.info(
            "[forward_backward] sending chunk %d/%d (%d datums)",
            index + 1,
            piece_count,
            len(piece),
        )
        outcome = await _post_and_poll(self, session_id, endpoint_path, _request_for(piece))
        if not isinstance(outcome, ForwardBackwardOperationResult):
            # Coerce into the type the combiner is given, defaulting absent attributes.
            outcome = ForwardBackwardOperationResult(
                total_loss=getattr(outcome, "total_loss", 0.0),
                per_datum_logprobs=getattr(outcome, "per_datum_logprobs", None),
                metrics=getattr(outcome, "metrics", None),
            )
        return outcome

    # Fire all chunks in parallel.
    gathered = await asyncio.gather(
        *[_run_piece(index, piece) for index, piece in enumerate(pieces)]
    )
    return _combine_fwd_bwd_results(list(gathered), [len(piece) for piece in pieces])
async def forward_post(
    self: "FineTuningSessionClient",
    session_id: str,
    batch: List[Datum],
    *,
    loss_fn: Union[str, LossFn] = LossFn.CROSS_ENTROPY,
    loss_fn_config: Optional[LossFnConfig] = None,
) -> "PendingRequests":
    """Enqueue a forward-only pass and return without waiting for the result.

    The batch is split with ``_chunk_data`` and the chunks are POSTed one
    after another (each POST is awaited before the next starts) so the server
    assigns monotonically increasing UUID v7 request IDs. The returned
    :class:`PendingRequests` handle exposes ``poll_result()`` for collecting
    the outcome later.

    :param session_id: The session ID.
    :param batch: List of Datum.
    :param loss_fn: Loss function name.
    :param loss_fn_config: Optional per-loss hyper-parameters.
    :return: PendingRequests handle.
    """
    pieces = _chunk_data(batch)
    piece_count = len(pieces)

    if piece_count > 1:
        _logger.info(
            "[forward_post] batch of %d datums split into %d chunks: %s",
            len(batch),
            piece_count,
            [len(piece) for piece in pieces],
        )

    endpoint_path = f"/fine_tuning/sessions/{session_id}/forward"
    enqueued: List[tuple] = []
    for piece in pieces:
        payload = ForwardRequest(
            forward_input=ForwardInput(
                data=piece,
                loss_fn=loss_fn,
                loss_fn_config=loss_fn_config,
            )
        )
        # Sequential on purpose: each POST finishes before the next is issued.
        rid, kind = await _post(self, endpoint_path, payload)
        enqueued.append((rid, kind))

    return PendingRequests(self, session_id, enqueued, pieces)
class PendingRequests:
    """Opaque handle returned by ``*_post`` methods.

    Holds the request IDs assigned by the server so that
    ``poll_result`` can wait for completion later. This lets callers
    guarantee POST ordering (UUID v7) while backgrounding the poll.
    """

    def __init__(
        self,
        client: "FineTuningSessionClient",
        session_id: str,
        posted: List[tuple],
        chunks: Optional[List[List[Datum]]] = None,
        extra_result_fields: Optional[dict] = None,
    ):
        """Store what ``poll_result`` needs to poll the posted requests.

        :param client: Client used to issue the poll requests.
        :param session_id: Session the requests belong to.
        :param posted: ``(request_id, op_type)`` pairs, one per POSTed request.
        :param chunks: The data chunks behind ``posted``; used only for their
            lengths when combining multi-chunk results.
        :param extra_result_fields: Passed through to ``_poll`` unchanged.
        """
        self._client = client
        self._session_id = session_id
        self._posted = posted  # list of (request_id, op_type)
        self._chunks = chunks  # set by the chunking forward_post / forward_backward_post
        self._extra_result_fields = extra_result_fields

    async def _poll_one(self, rid: str, ot: str) -> OperationResult:
        """Poll one posted request until ``_poll`` returns its result."""
        return await _poll(
            self._client, self._session_id, rid, ot,
            extra_result_fields=self._extra_result_fields,
            error_budget_sec=_DEFAULT_OPERATION_TIMEOUT_SEC,
        )

    async def _poll_chunk(self, rid: str, ot: str) -> ForwardBackwardOperationResult:
        """Poll one chunk and coerce its result for ``_combine_fwd_bwd_results``."""
        result = await self._poll_one(rid, ot)
        if isinstance(result, ForwardBackwardOperationResult):
            return result
        return ForwardBackwardOperationResult(
            total_loss=getattr(result, "total_loss", 0.0),
            per_datum_logprobs=getattr(result, "per_datum_logprobs", None),
            metrics=getattr(result, "metrics", None),
        )

    async def poll_result(self) -> OperationResult:
        """Poll until all POSTed requests complete and return the combined result.

        A single request is returned as-is. Multiple requests are polled
        concurrently and merged with ``_combine_fwd_bwd_results``. If any poll
        raises, or this coroutine is cancelled, the polls still in flight are
        cancelled before the exception propagates.

        :return: The single result, or the combined multi-chunk result.
        """
        if len(self._posted) == 1:
            rid, ot = self._posted[0]
            return await self._poll_one(rid, ot)

        # Multiple chunks — poll in parallel and combine.
        poll_tasks = [
            asyncio.create_task(self._poll_chunk(rid, ot)) for rid, ot in self._posted
        ]
        try:
            chunk_results = await asyncio.gather(*poll_tasks)
        except BaseException:
            # gather() does not cancel siblings when one awaitable fails; without
            # this the remaining polls would keep hitting the service unobserved.
            for task in poll_tasks:
                if not task.done():
                    task.cancel()
            raise

        chunk_sizes = [len(c) for c in self._chunks] if self._chunks else [1] * len(self._posted)
        return _combine_fwd_bwd_results(list(chunk_results), chunk_sizes)
async def forward_backward_async(
    self: "FineTuningSessionClient",
    session_id: str,
    batch: List[Datum],
    *,
    loss_fn: Union[str, LossFn] = LossFn.CROSS_ENTROPY,
    loss_fn_config: Optional[LossFnConfig] = None,
) -> "asyncio.Task[OperationResult]":
    """Post a forward+backward pass and hand back an awaitable for its result.

    All POSTs (one per chunk) are awaited inside this call, so the server has
    assigned UUID v7 request IDs by the time it returns.

    **Single chunk (the common case):** only polling is deferred; it runs in a
    background task named ``"fwd_bwd_poll"``.

    **Multiple chunks:** this call also waits for every chunk's result before
    returning, so that a following ``optim_step_async`` POST cannot reach the
    engine before all chunk gradients are accumulated. The returned awaitable
    is then an already-resolved future holding the combined result.

    :param session_id: The session ID.
    :param batch: List of Datum.
    :param loss_fn: Loss function name. Defaults to ``"cross_entropy"``.
    :param loss_fn_config: Optional per-loss hyper-parameters.
    :return: An awaitable (a Task, or a completed Future in the multi-chunk
        case) whose result is an OperationResult.
    """
    handle = await forward_backward_post(
        self, session_id, batch, loss_fn=loss_fn, loss_fn_config=loss_fn_config
    )
    posted_count = len(handle._posted)

    if posted_count <= 1:
        # Nothing to order against: let the poll run in the background.
        return asyncio.create_task(handle.poll_result(), name="fwd_bwd_poll")

    # Multi-chunk path: block here until every chunk has completed.
    #
    # Rationale (from the engine's behaviour): the engine polls the DB on a
    # 0.5-2s cadence. If chunk1-POST, chunk2-POST and optim-POST all land
    # within ~600ms, the engine may poll between chunk1 and chunk2, see
    # chunk1 with no barrier, and run chunk1 + optim_step before chunk2
    # arrives — applying gradients from only half the batch.
    #
    # Waiting for all chunks before the caller can post optim_step removes
    # that race, at the price of roughly one extra engine poll cycle
    # (~0.5-2s) of latency.
    _logger.info(
        "[forward_backward_async] multi-chunk (%d): awaiting all chunk results before returning",
        posted_count,
    )
    combined = await handle.poll_result()
    resolved: asyncio.Future[OperationResult] = asyncio.get_running_loop().create_future()
    resolved.set_result(combined)
    return resolved  # type: ignore[return-value]  # Future is awaitable like Task
async def save_weights(
    self: "FineTuningSessionClient",
    session_id: str,
    path: str,
    *,
    step_number: Optional[int] = None,
    metrics: Optional[Dict[str, Any]] = None,
) -> OperationResult:
    """Save a training checkpoint (LoRA weights + optimizer state) and wait for it.

    Accepts the same optional checkpoint metadata as ``save_weights_post`` and
    ``save_weights_async``, so the blocking variant can record it too. When
    neither is given, the request is built exactly as before (``path`` only).

    :param session_id: The session ID.
    :param path: Checkpoint name/path.
    :param step_number: Training step number for this checkpoint.
    :param metrics: Evaluation metrics at checkpoint time.
    :return: OperationResult.
    """
    # Only pass metadata that was actually supplied, so existing callers send
    # an unchanged request.
    request_fields: Dict[str, Any] = {"path": path}
    if step_number is not None:
        request_fields["step_number"] = step_number
    if metrics is not None:
        request_fields["metrics"] = metrics

    return await _post_and_poll(
        self,
        session_id,
        f"/fine_tuning/sessions/{session_id}/checkpoint",
        SaveCheckpointRequest(**request_fields),
    )
async def _save_weights_for_sampler_post(
    self: "FineTuningSessionClient",
    session_id: str,
    *,
    sampling_session_seq_id: Optional[int] = None,
    path: Optional[str] = None,
) -> "PendingRequests":
    """Internal: enqueue a save-weights-for-sampler request without polling.

    POSTs a ``SaveSamplerWeightsRequest`` to the session's
    ``checkpoint_sample`` endpoint and wraps the returned request in a
    :class:`PendingRequests` handle. The handle carries
    ``{"checkpoint_id": path or ""}`` as extra result fields.

    :param session_id: The session ID.
    :param sampling_session_seq_id: Forwarded to the request as-is.
    :param path: Checkpoint name; also becomes the ``checkpoint_id`` extra
        field (empty string when falsy).
    :return: PendingRequests handle for the single posted request.
    """
    sampler_request = SaveSamplerWeightsRequest(
        seq_id=0,  # always sent as 0 by this helper
        sampling_session_seq_id=sampling_session_seq_id,
        path=path,
    )
    rid, kind = await _post(
        self,
        f"/fine_tuning/sessions/{session_id}/checkpoint_sample",
        sampler_request,
    )
    checkpoint_id = path if path else ""
    return PendingRequests(
        self,
        session_id,
        [(rid, kind)],
        extra_result_fields={"checkpoint_id": checkpoint_id},
    )
weights and persist them to blob storage. + + The engine persists the checkpoint because ``sampling_session_seq_id`` + is not set. + + :param session_id: The session ID. + :param name: Checkpoint name/identifier. + :return: An asyncio.Task whose result is an OperationResult with + ``checkpoint_id`` set to *name*. + """ + _ensure_async_state(self) + pending = await _save_weights_for_sampler_post( + self, session_id, path=name, + ) + return asyncio.create_task(pending.poll_result(), name=f"sampler_{name}") + + +async def save_weights_and_get_sampling_client_async( + self: "FineTuningSessionClient", + session_id: str, + name: str, +) -> "asyncio.Task[OperationResult]": + """Sync current LoRA weights to the sampler (ephemeral — not persisted). + + Used every training step to push weights for rollout sampling. + The engine skips blob persistence because ``sampling_session_seq_id`` + is set. The SDK maintains an internal per-session counter — the user + never sees it. + + :param session_id: The session ID. + :param name: Checkpoint name/identifier (e.g. ``"step5"``). + :return: An asyncio.Task whose result is an OperationResult with + ``checkpoint_id`` set to *name*. 
async def sample(
    self: "FineTuningSessionClient",
    session_id: str,
    prompt_tokens: List[int],
    sampling_params: SamplingParams,
    *,
    checkpoint_id: str,
    num_samples: int = 1,
    sampling_session_id: Optional[str] = None,
    seq_id: Optional[int] = None,
    prompt_logprobs: bool = False,
    topk_prompt_logprobs: int = 0,
) -> OperationResult:
    """Request completions for a tokenised prompt and wait for the result.

    Wraps ``prompt_tokens`` in a single-chunk ``ModelInput``, builds a
    ``SampleRequest`` and submits it to the session's ``sample`` endpoint via
    ``_post_and_poll``. ``checkpoint_id`` is not part of the request model; it
    is handed to ``_post_and_poll`` as ``extra_params``.

    :param session_id: The session ID.
    :param prompt_tokens: Tokenised prompt as a list of integer IDs.
    :param sampling_params: Generation parameters.
    :param checkpoint_id: Sampler checkpoint ID from ``save_weights_for_sampler``.
    :param num_samples: Number of completions. Default 1.
    :param sampling_session_id: Forwarded to the ``SampleRequest`` field of the
        same name.
    :param seq_id: Forwarded to the ``SampleRequest`` field of the same name.
    :param prompt_logprobs: Forwarded to the ``SampleRequest`` field of the
        same name. Default ``False``.
    :param topk_prompt_logprobs: Forwarded to the ``SampleRequest`` field of
        the same name. Default 0.
    :return: OperationResult.
    """
    prompt_input = ModelInput(chunks=[ModelInputChunk(tokens=prompt_tokens)])
    sample_request = SampleRequest(
        num_samples=num_samples,
        prompt=prompt_input,
        sampling_params=sampling_params,
        topk_prompt_logprobs=topk_prompt_logprobs,
        sampling_session_id=sampling_session_id,
        seq_id=seq_id,
        prompt_logprobs=prompt_logprobs,
    )
    endpoint_path = f"/fine_tuning/sessions/{session_id}/sample"
    query_extras = {"checkpoint_id": checkpoint_id}
    return await _post_and_poll(
        self,
        session_id,
        endpoint_path,
        sample_request,
        extra_params=query_extras,
    )
def patch_sdk():
    """Attach this module's async convenience functions to FineTuningSessionClient.

    Each function is installed as a class attribute under its own name, so it
    becomes a method of the generated client.
    """
    from ._client import FineTuningSessionClient

    convenience_methods = (
        create_session,
        create_session_from_checkpoint,
        forward_backward,
        forward_backward_post,
        forward_backward_async,
        forward,
        forward_post,
        forward_async,
        optim_step,
        optim_step_post,
        optim_step_async,
        save_weights,
        save_weights_post,
        save_weights_async,
        save_weights_for_sampler_async,
        save_weights_and_get_sampling_client_async,
        sample,
        close_session,
    )
    for method in convenience_methods:
        # Every entry is a plain ``async def``, so __name__ is its public name.
        setattr(FineTuningSessionClient, method.__name__, method)
b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/aio/operations/__init__.py @@ -0,0 +1,33 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._operations import SessionsOperations # type: ignore +from ._operations import TrainingOperations # type: ignore +from ._operations import CheckpointsOperations # type: ignore +from ._operations import SamplingOperations # type: ignore +from ._operations import Operations # type: ignore + +from ._patch import __all__ as _patch_all +from ._patch import * +from ._patch import patch_sdk as _patch_sdk + +__all__ = [ + "SessionsOperations", + "TrainingOperations", + "CheckpointsOperations", + "SamplingOperations", + "Operations", +] +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore +_patch_sdk() diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/aio/operations/_operations.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/aio/operations/_operations.py new file mode 100644 index 000000000000..51ba49cfa3c7 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/aio/operations/_operations.py @@ -0,0 +1,2614 @@ +# pylint: disable=too-many-lines +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +from collections.abc import MutableMapping +from io import IOBase +import json +from typing import Any, AsyncIterator, Callable, IO, Literal, Optional, TypeVar, Union, cast, overload + +from azure.core import AsyncPipelineClient +from azure.core.exceptions import ( + ClientAuthenticationError, + HttpResponseError, + ResourceExistsError, + ResourceNotFoundError, + ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, + map_error, +) +from azure.core.pipeline import PipelineResponse +from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod +from azure.core.polling.async_base_polling import AsyncLROBasePolling +from azure.core.rest import AsyncHttpResponse, HttpRequest +from azure.core.tracing.decorator_async import distributed_trace_async +from azure.core.utils import case_insensitive_dict + +from ... 
import models as _models +from ..._utils.model_base import SdkJSONEncoder, _deserialize, _failsafe_deserialize +from ..._utils.serialization import Deserializer, Serializer +from ..._validation import api_version_validation +from ...models._enums import FoundryFeaturesOptInKeys +from ...operations._operations import ( + build_checkpoints_get_request, + build_checkpoints_list_request, + build_checkpoints_save_request, + build_checkpoints_save_sampler_weights_request, + build_operations_get_request, + build_sampling_sample_request, + build_sessions_create_request, + build_sessions_get_request, + build_sessions_heartbeat_request, + build_sessions_list_request, + build_sessions_unload_request, + build_training_forward_backward_request, + build_training_optim_step_request, +) +from .._configuration import FineTuningSessionClientConfiguration + +JSON = MutableMapping[str, Any] +T = TypeVar("T") +ClsType = Optional[Callable[[PipelineResponse[HttpRequest, AsyncHttpResponse], T, dict[str, Any]], Any]] +List = list + + +class SessionsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~azure.ai.finetuning_sessions.aio.FineTuningSessionClient`'s + :attr:`sessions` attribute. 
@api_version_validation(
    params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]},
)
async def _create_initial(
    self,
    body: Union[_models.CreateSessionRequest, JSON, IO[bytes]],
    *,
    foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW],
    api_version: str,
    **kwargs: Any
) -> AsyncIterator[bytes]:
    """Send the initial request of the create-session long-running operation.

    The request is always sent with ``stream=True``. Only HTTP 202 is treated as success;
    for any other status the body is read, the status is passed through ``map_error`` and an
    ``HttpResponseError`` carrying the ``_failsafe_deserialize`` result is raised.

    :param body: Request payload. ``IOBase``/``bytes`` values are sent as-is; anything else is
     serialized with ``json.dumps`` using ``SdkJSONEncoder``.
    :type body: ~azure.ai.finetuning_sessions.models.CreateSessionRequest or JSON or IO[bytes]
    :keyword foundry_features: Forwarded to the request builder. Required.
    :keyword api_version: Forwarded to the request builder. Required.
    :return: The response byte iterator, or the result of the ``cls`` callback when one is given.
    :rtype: AsyncIterator[bytes]
    :raises ~azure.core.exceptions.HttpResponseError:
    """
    status_errors: MutableMapping = {
        401: ClientAuthenticationError,
        404: ResourceNotFoundError,
        409: ResourceExistsError,
        304: ResourceNotModifiedError,
    }
    status_errors.update(kwargs.pop("error_map", {}) or {})

    request_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
    query_params = kwargs.pop("params", {}) or {}

    # The "Content-Type" header is popped unconditionally (it is the default argument), so it
    # never reaches the request builder through `headers`.
    header_content_type = request_headers.pop("Content-Type", None)
    requested_content_type: Optional[str] = kwargs.pop("content_type", header_content_type)
    cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None)
    content_type = requested_content_type or "application/json"

    if isinstance(body, (IOBase, bytes)):
        payload = body
    else:
        payload = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True)  # type: ignore

    request = build_sessions_create_request(
        foundry_features=foundry_features,
        api_version=api_version,
        content_type=content_type,
        content=payload,
        headers=request_headers,
        params=query_params,
    )
    endpoint = self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True)
    request.url = self._client.format_url(request.url, endpoint=endpoint)

    decompress = kwargs.pop("decompress", True)
    pipeline_response: PipelineResponse = await self._client._pipeline.run(  # pylint: disable=protected-access
        request, stream=True, **kwargs
    )
    response = pipeline_response.http_response

    if response.status_code != 202:
        try:
            await response.read()  # Load the body in memory and close the socket
        except (StreamConsumedError, StreamClosedError):
            pass
        map_error(status_code=response.status_code, response=response, error_map=status_errors)
        raise HttpResponseError(
            response=response,
            model=_failsafe_deserialize(_models.ApiErrorResponse, response),
        )

    response_headers = {
        "Operation-Location": self._deserialize("str", response.headers.get("Operation-Location")),
    }

    if decompress:
        body_iterator = response.iter_bytes()
    else:
        body_iterator = response.iter_raw()

    return cls(pipeline_response, body_iterator, response_headers) if cls else body_iterator  # type: ignore
+ :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_create( + self, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Create a fine-tuning session. + + Create a new fine-tuning session and allocate it to a GPU engine. + + Returns ``session_id`` array ``["session_xxx", "sampling_xxx"]``. + Use ``session_xxx`` as ``{sessionId}`` for training ops and + ``sampling_xxx`` as ``{samplingId}`` for sampling ops. + + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_create( + self, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Create a fine-tuning session. + + Create a new fine-tuning session and allocate it to a GPU engine. + + Returns ``session_id`` array ``["session_xxx", "sampling_xxx"]``. + Use ``session_xxx`` as ``{sessionId}`` for training ops and + ``sampling_xxx`` as ``{samplingId}`` for sampling ops. + + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def begin_create( + self, + body: Union[_models.CreateSessionRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Create a fine-tuning session. + + Create a new fine-tuning session and allocate it to a GPU engine. + + Returns ``session_id`` array ``["session_xxx", "sampling_xxx"]``. + Use ``session_xxx`` as ``{sessionId}`` for training ops and + ``sampling_xxx`` as ``{samplingId}`` for sampling ops. + + :param body: Is one of the following types: CreateSessionRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.CreateSessionRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, AsyncPollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = await self._create_initial( + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + await raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: AsyncPollingMethod = cast( + AsyncPollingMethod, + AsyncLROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs), + ) + elif polling is False: + polling_method = cast(AsyncPollingMethod, AsyncNoPolling()) + else: + polling_method = polling + if cont_token: + return 
@distributed_trace_async
@api_version_validation(
    params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]},
)
async def create(
    self,
    body: Union[_models.CreateSessionRequest, JSON, IO[bytes]],
    *,
    foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW],
    api_version: str,
    **kwargs: Any
) -> JSON:
    """Create a fine-tuning session (synchronous 200 response).

    POSTs to ``/fine_tuning/sessions`` and returns the response body directly.
    Use this instead of ``begin_create`` when the server returns HTTP 200 (not 202 LRO).

    When ``stream=True`` is passed, the body is not parsed: the response byte iterator is
    returned instead (``iter_bytes`` by default, ``iter_raw`` when ``decompress=False``),
    matching the other non-LRO operations of this class.

    :param body: Is one of the following types: CreateSessionRequest, JSON, IO[bytes] Required.
    :type body: ~azure.ai.finetuning_sessions.models.CreateSessionRequest or JSON or IO[bytes]
    :keyword foundry_features: Required.
    :paramtype foundry_features: str or
     ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW
    :keyword api_version: The API version to use for this operation. Required.
    :paramtype api_version: str
    :return: dict (JSON response body)
    :rtype: dict
    :raises ~azure.core.exceptions.HttpResponseError:
    """
    error_map: MutableMapping = {
        401: ClientAuthenticationError,
        404: ResourceNotFoundError,
        409: ResourceExistsError,
        304: ResourceNotModifiedError,
    }
    error_map.update(kwargs.pop("error_map", {}) or {})

    _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
    _params = kwargs.pop("params", {}) or {}

    content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
    cls: ClsType[JSON] = kwargs.pop("cls", None)

    content_type = content_type or "application/json"
    if isinstance(body, (IOBase, bytes)):
        # Binary payloads are forwarded untouched.
        _content = body
    else:
        _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True)  # type: ignore

    _request = build_sessions_create_request(
        foundry_features=foundry_features,
        api_version=api_version,
        content_type=content_type,
        content=_content,
        headers=_headers,
        params=_params,
    )
    path_format_arguments = {
        "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True),
    }
    _request.url = self._client.format_url(_request.url, **path_format_arguments)

    # Popped here, like in list/get/heartbeat, so it selects the iterator below.
    _decompress = kwargs.pop("decompress", True)
    _stream = kwargs.pop("stream", False)
    pipeline_response: PipelineResponse = await self._client._pipeline.run(  # pylint: disable=protected-access
        _request, stream=_stream, **kwargs
    )

    response = pipeline_response.http_response

    if response.status_code not in [200]:
        if _stream:
            try:
                await response.read()  # Load the body in memory and close the socket
            except (StreamConsumedError, StreamClosedError):
                pass
        map_error(status_code=response.status_code, response=response, error_map=error_map)
        error = _failsafe_deserialize(_models.ApiErrorResponse, response)
        raise HttpResponseError(response=response, model=error)

    if _stream:
        # A streamed body has not been read, so response.json() cannot be used here;
        # hand the caller the byte iterator instead.
        deserialized = response.iter_bytes() if _decompress else response.iter_raw()
    else:
        deserialized = response.json()

    if cls:
        return cls(pipeline_response, deserialized, {})  # type: ignore

    return deserialized  # type: ignore
@distributed_trace_async
@api_version_validation(
    params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]},
)
async def get(
    self,
    session_id: str,
    *,
    foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW],
    api_version: str,
    **kwargs: Any
) -> _models.Session:
    """Get a fine-tuning session.

    Get information about a specific fine-tuning session.

    With ``stream=True`` the response byte iterator is returned instead of a deserialized
    ``Session`` (``iter_raw`` when ``decompress=False``, otherwise ``iter_bytes``).

    :param session_id: Required.
    :type session_id: str
    :keyword foundry_features: A feature flag opt-in required when using preview operations or
     modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required.
    :paramtype foundry_features: str or
     ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW
    :keyword api_version: The API version to use for this operation. Required.
    :paramtype api_version: str
    :return: Session. The Session is compatible with MutableMapping
    :rtype: ~azure.ai.finetuning_sessions.models.Session
    :raises ~azure.core.exceptions.HttpResponseError:
    """
    status_errors: MutableMapping = {
        401: ClientAuthenticationError,
        404: ResourceNotFoundError,
        409: ResourceExistsError,
        304: ResourceNotModifiedError,
    }
    status_errors.update(kwargs.pop("error_map", {}) or {})

    request_headers = kwargs.pop("headers", {}) or {}
    query_params = kwargs.pop("params", {}) or {}
    cls: ClsType[_models.Session] = kwargs.pop("cls", None)

    request = build_sessions_get_request(
        session_id=session_id,
        foundry_features=foundry_features,
        api_version=api_version,
        headers=request_headers,
        params=query_params,
    )
    endpoint = self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True)
    request.url = self._client.format_url(request.url, endpoint=endpoint)

    decompress = kwargs.pop("decompress", True)
    stream = kwargs.pop("stream", False)
    pipeline_response: PipelineResponse = await self._client._pipeline.run(  # pylint: disable=protected-access
        request, stream=stream, **kwargs
    )
    response = pipeline_response.http_response

    if response.status_code != 200:
        if stream:
            try:
                await response.read()  # Load the body in memory and close the socket
            except (StreamConsumedError, StreamClosedError):
                pass
        map_error(status_code=response.status_code, response=response, error_map=status_errors)
        raise HttpResponseError(
            response=response,
            model=_failsafe_deserialize(_models.ApiErrorResponse, response),
        )

    if not stream:
        result = _deserialize(_models.Session, response.json())
    elif decompress:
        result = response.iter_bytes()
    else:
        result = response.iter_raw()

    return cls(pipeline_response, result, {}) if cls else result  # type: ignore
@distributed_trace_async
@api_version_validation(
    params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]},
)
async def begin_unload(
    self,
    session_id: str,
    *,
    foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW],
    api_version: str,
    **kwargs: Any
) -> AsyncLROPoller[_models.OperationResult]:
    """Unload a fine-tuning session.

    Unload a session from the GPU engine, freeing memory. LoRA weights are lost; save a checkpoint
    before calling this.

    The initial request is skipped when a ``continuation_token`` keyword is supplied; the poller
    is then rebuilt from that token.

    :param session_id: Required.
    :type session_id: str
    :keyword foundry_features: A feature flag opt-in required when using preview operations or
     modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required.
    :paramtype foundry_features: str or
     ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW
    :keyword api_version: The API version to use for this operation. Required.
    :paramtype api_version: str
    :return: An instance of AsyncLROPoller that returns OperationResult. The OperationResult is
     compatible with MutableMapping
    :rtype:
     ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult]
    :raises ~azure.core.exceptions.HttpResponseError:
    """
    request_headers = kwargs.pop("headers", {}) or {}
    query_params = kwargs.pop("params", {}) or {}

    cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None)
    polling: Union[bool, AsyncPollingMethod] = kwargs.pop("polling", True)
    poll_interval = kwargs.pop("polling_interval", self._config.polling_interval)
    resume_token: Optional[str] = kwargs.pop("continuation_token", None)

    if resume_token is None:
        # The callback makes the initial call hand back the pipeline response itself.
        initial_response = await self._unload_initial(
            session_id=session_id,
            foundry_features=foundry_features,
            api_version=api_version,
            cls=lambda pipeline_response, _body, _response_headers: pipeline_response,
            headers=request_headers,
            params=query_params,
            **kwargs
        )
        await initial_response.http_response.read()  # type: ignore
    # `error_map` must not be forwarded to the polling method below.
    kwargs.pop("error_map", None)

    def _final_result(pipeline_response):
        """Build the poller's final value from the ``result`` member of the response JSON."""
        http_response = pipeline_response.http_response
        operation_location = self._deserialize("str", http_response.headers.get("Operation-Location"))
        result = _deserialize(_models.OperationResult, http_response.json().get("result", {}))
        if not cls:
            return result
        return cls(pipeline_response, result, {"Operation-Location": operation_location})  # type: ignore

    endpoint = self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True)

    if polling is False:
        method: AsyncPollingMethod = cast(AsyncPollingMethod, AsyncNoPolling())
    elif polling is True:
        method = cast(
            AsyncPollingMethod,
            AsyncLROBasePolling(poll_interval, path_format_arguments={"endpoint": endpoint}, **kwargs),
        )
    else:
        # The caller supplied its own polling method.
        method = polling

    if resume_token:
        return AsyncLROPoller[_models.OperationResult].from_continuation_token(
            polling_method=method,
            continuation_token=resume_token,
            client=self._client,
            deserialization_callback=_final_result,
        )
    return AsyncLROPoller[_models.OperationResult](
        self._client, initial_response, _final_result, method  # type: ignore
    )
The HeartbeatResponse is compatible with MutableMapping + :rtype: ~azure.ai.finetuning_sessions.models.HeartbeatResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.HeartbeatResponse] = kwargs.pop("cls", None) + + _request = build_sessions_heartbeat_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + if _stream: + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + else: + deserialized = _deserialize(_models.HeartbeatResponse, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class TrainingOperations: + """ + .. 
@api_version_validation(
    params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]},
)
async def _forward_backward_initial(
    self,
    session_id: str,
    body: Union[_models.ForwardBackwardRequest, JSON, IO[bytes]],
    *,
    foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW],
    api_version: str,
    **kwargs: Any
) -> AsyncIterator[bytes]:
    """Send the initial request of the forward/backward long-running operation.

    The request is always sent with ``stream=True``. Only HTTP 202 is treated as success;
    for any other status the body is read, the status is passed through ``map_error`` and an
    ``HttpResponseError`` carrying the ``_failsafe_deserialize`` result is raised.

    :param session_id: Forwarded to the request builder. Required.
    :type session_id: str
    :param body: Request payload. ``IOBase``/``bytes`` values are sent as-is; anything else is
     serialized with ``json.dumps`` using ``SdkJSONEncoder``.
    :type body: ~azure.ai.finetuning_sessions.models.ForwardBackwardRequest or JSON or IO[bytes]
    :keyword foundry_features: Forwarded to the request builder. Required.
    :keyword api_version: Forwarded to the request builder. Required.
    :return: The response byte iterator, or the result of the ``cls`` callback when one is given.
    :rtype: AsyncIterator[bytes]
    :raises ~azure.core.exceptions.HttpResponseError:
    """
    status_errors: MutableMapping = {
        401: ClientAuthenticationError,
        404: ResourceNotFoundError,
        409: ResourceExistsError,
        304: ResourceNotModifiedError,
    }
    status_errors.update(kwargs.pop("error_map", {}) or {})

    request_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
    query_params = kwargs.pop("params", {}) or {}

    # The "Content-Type" header is popped unconditionally (it is the default argument), so it
    # never reaches the request builder through `headers`.
    header_content_type = request_headers.pop("Content-Type", None)
    requested_content_type: Optional[str] = kwargs.pop("content_type", header_content_type)
    cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None)
    content_type = requested_content_type or "application/json"

    if isinstance(body, (IOBase, bytes)):
        payload = body
    else:
        payload = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True)  # type: ignore

    request = build_training_forward_backward_request(
        session_id=session_id,
        foundry_features=foundry_features,
        api_version=api_version,
        content_type=content_type,
        content=payload,
        headers=request_headers,
        params=query_params,
    )
    endpoint = self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True)
    request.url = self._client.format_url(request.url, endpoint=endpoint)

    decompress = kwargs.pop("decompress", True)
    pipeline_response: PipelineResponse = await self._client._pipeline.run(  # pylint: disable=protected-access
        request, stream=True, **kwargs
    )
    response = pipeline_response.http_response

    if response.status_code != 202:
        try:
            await response.read()  # Load the body in memory and close the socket
        except (StreamConsumedError, StreamClosedError):
            pass
        map_error(status_code=response.status_code, response=response, error_map=status_errors)
        raise HttpResponseError(
            response=response,
            model=_failsafe_deserialize(_models.ApiErrorResponse, response),
        )

    response_headers = {
        "Operation-Location": self._deserialize("str", response.headers.get("Operation-Location")),
    }

    if decompress:
        body_iterator = response.iter_bytes()
    else:
        body_iterator = response.iter_raw()

    return cls(pipeline_response, body_iterator, response_headers) if cls else body_iterator  # type: ignore
+ :type body: ~azure.ai.finetuning_sessions.models.ForwardBackwardRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_forward_backward( + self, + session_id: str, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Forward and backward pass. + + Submit a mini-batch for a combined forward + backward pass. + + Gradients accumulate until an optimizer step is issued. + Poll the returned operation URL for ``ForwardBackwardResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. 
Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_forward_backward( + self, + session_id: str, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Forward and backward pass. + + Submit a mini-batch for a combined forward + backward pass. + + Gradients accumulate until an optimizer step is issued. + Poll the returned operation URL for ``ForwardBackwardResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def begin_forward_backward( + self, + session_id: str, + body: Union[_models.ForwardBackwardRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Forward and backward pass. + + Submit a mini-batch for a combined forward + backward pass. + + Gradients accumulate until an optimizer step is issued. + Poll the returned operation URL for ``ForwardBackwardResult``. + + :param session_id: Required. + :type session_id: str + :param body: Is one of the following types: ForwardBackwardRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.ForwardBackwardRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, AsyncPollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = await self._forward_backward_initial( + session_id=session_id, + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + await raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: AsyncPollingMethod = cast( + AsyncPollingMethod, + AsyncLROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs), + ) + elif polling is False: + polling_method = cast(AsyncPollingMethod, AsyncNoPolling()) + else: + polling_method = polling + if cont_token: + return 
AsyncLROPoller[_models.OperationResult].from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return AsyncLROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def _optim_step_initial( + self, + session_id: str, + body: Union[_models.OptimStepRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> AsyncIterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_training_optim_step_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = True + pipeline_response: 
PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def begin_optim_step( + self, + session_id: str, + body: _models.OptimStepRequest, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Optimizer step. + + Apply accumulated gradients to the LoRA weights using the Adam optimizer. + + Poll the returned operation URL for ``OptimStepResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: ~azure.ai.finetuning_sessions.models.OptimStepRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. 
+ :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_optim_step( + self, + session_id: str, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Optimizer step. + + Apply accumulated gradients to the LoRA weights using the Adam optimizer. + + Poll the returned operation URL for ``OptimStepResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_optim_step( + self, + session_id: str, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Optimizer step. + + Apply accumulated gradients to the LoRA weights using the Adam optimizer. + + Poll the returned operation URL for ``OptimStepResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def begin_optim_step( + self, + session_id: str, + body: Union[_models.OptimStepRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Optimizer step. + + Apply accumulated gradients to the LoRA weights using the Adam optimizer. + + Poll the returned operation URL for ``OptimStepResult``. + + :param session_id: Required. + :type session_id: str + :param body: Is one of the following types: OptimStepRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.OptimStepRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, AsyncPollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = await self._optim_step_initial( + session_id=session_id, + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + await raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: AsyncPollingMethod = cast( + AsyncPollingMethod, + AsyncLROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs), + ) + elif polling is False: + polling_method = cast(AsyncPollingMethod, AsyncNoPolling()) + else: + polling_method = polling + if cont_token: + return 
AsyncLROPoller[_models.OperationResult].from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return AsyncLROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + +class CheckpointsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~azure.ai.finetuning_sessions.aio.FineTuningSessionClient`'s + :attr:`checkpoints` attribute. + """ + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: FineTuningSessionClientConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def _save_initial( + self, + session_id: str, + body: Union[_models.SaveCheckpointRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> AsyncIterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) + + content_type = 
content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_checkpoints_save_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = True + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def begin_save( + self, + session_id: str, + body: _models.SaveCheckpointRequest, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Save a 
training checkpoint. + + Save a training checkpoint (LoRA weights + optimizer state) to blob storage. + + Poll the returned operation URL for ``SaveCheckpointResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: ~azure.ai.finetuning_sessions.models.SaveCheckpointRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_save( + self, + session_id: str, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Save a training checkpoint. + + Save a training checkpoint (LoRA weights + optimizer state) to blob storage. + + Poll the returned operation URL for ``SaveCheckpointResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. 
+ :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_save( + self, + session_id: str, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Save a training checkpoint. + + Save a training checkpoint (LoRA weights + optimizer state) to blob storage. + + Poll the returned operation URL for ``SaveCheckpointResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def begin_save( + self, + session_id: str, + body: Union[_models.SaveCheckpointRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Save a training checkpoint. + + Save a training checkpoint (LoRA weights + optimizer state) to blob storage. + + Poll the returned operation URL for ``SaveCheckpointResult``. + + :param session_id: Required. + :type session_id: str + :param body: Is one of the following types: SaveCheckpointRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.SaveCheckpointRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, AsyncPollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = await self._save_initial( + session_id=session_id, + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + await raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: AsyncPollingMethod = cast( + AsyncPollingMethod, + AsyncLROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs), + ) + elif polling is False: + polling_method = cast(AsyncPollingMethod, AsyncNoPolling()) + else: + polling_method = polling + if cont_token: + return 
AsyncLROPoller[_models.OperationResult].from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return AsyncLROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def _save_sampler_weights_initial( + self, + session_id: str, + body: Union[_models.SaveSamplerWeightsRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> AsyncIterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_checkpoints_save_sampler_weights_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = True + 
pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def begin_save_sampler_weights( + self, + session_id: str, + body: _models.SaveSamplerWeightsRequest, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Save sampler weights. + + Save sampler-compatible weights (no optimizer state) for generation. + + Poll the returned operation URL for ``SaveSamplerWeightsResult``. + The result contains ``sampling_session_id`` — pass this to subsequent sample calls. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: ~azure.ai.finetuning_sessions.models.SaveSamplerWeightsRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. 
+ :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_save_sampler_weights( + self, + session_id: str, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Save sampler weights. + + Save sampler-compatible weights (no optimizer state) for generation. + + Poll the returned operation URL for ``SaveSamplerWeightsResult``. + The result contains ``sampling_session_id`` — pass this to subsequent sample calls. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_save_sampler_weights( + self, + session_id: str, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Save sampler weights. + + Save sampler-compatible weights (no optimizer state) for generation. + + Poll the returned operation URL for ``SaveSamplerWeightsResult``. + The result contains ``sampling_session_id`` — pass this to subsequent sample calls. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def begin_save_sampler_weights( + self, + session_id: str, + body: Union[_models.SaveSamplerWeightsRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Save sampler weights. + + Save sampler-compatible weights (no optimizer state) for generation. + + Poll the returned operation URL for ``SaveSamplerWeightsResult``. + The result contains ``sampling_session_id`` — pass this to subsequent sample calls. + + :param session_id: Required. + :type session_id: str + :param body: Is one of the following types: SaveSamplerWeightsRequest, JSON, IO[bytes] + Required. + :type body: ~azure.ai.finetuning_sessions.models.SaveSamplerWeightsRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, AsyncPollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = await self._save_sampler_weights_initial( + session_id=session_id, + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + await raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: AsyncPollingMethod = cast( + AsyncPollingMethod, + AsyncLROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs), + ) + elif polling is False: + polling_method = cast(AsyncPollingMethod, AsyncNoPolling()) + else: + polling_method = polling + if cont_token: + return 
AsyncLROPoller[_models.OperationResult].from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return AsyncLROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + @distributed_trace_async + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def list( + self, + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> _models.CheckpointList: + """List checkpoints. + + List all checkpoints (training and sampler) for this session. + + :param session_id: Required. + :type session_id: str + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: CheckpointList. 
The CheckpointList is compatible with MutableMapping + :rtype: ~azure.ai.finetuning_sessions.models.CheckpointList + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.CheckpointList] = kwargs.pop("cls", None) + + _request = build_checkpoints_list_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + if _stream: + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + else: + deserialized = _deserialize(_models.CheckpointList, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + @api_version_validation( + 
params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def get( + self, + session_id: str, + checkpoint_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> _models.CheckpointInfo: + """Get checkpoint info. + + Get metadata for a specific checkpoint. + + :param session_id: Required. + :type session_id: str + :param checkpoint_id: Required. + :type checkpoint_id: str + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: CheckpointInfo. The CheckpointInfo is compatible with MutableMapping + :rtype: ~azure.ai.finetuning_sessions.models.CheckpointInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.CheckpointInfo] = kwargs.pop("cls", None) + + _request = build_checkpoints_get_request( + session_id=session_id, + checkpoint_id=checkpoint_id, + foundry_features=foundry_features, + api_version=api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = kwargs.pop("stream", False) + pipeline_response: 
PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + if _stream: + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + else: + deserialized = _deserialize(_models.CheckpointInfo, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class SamplingOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~azure.ai.finetuning_sessions.aio.FineTuningSessionClient`'s + :attr:`sampling` attribute. 
+ """ + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: FineTuningSessionClientConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def _sample_initial( + self, + session_id: str, + body: Union[_models.SampleRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> AsyncIterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_sampling_sample_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = 
self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = True + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [202]: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def begin_sample( + self, + session_id: str, + body: _models.SampleRequest, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Generate completions. + + Generate one or more completions using the session's current LoRA weights. + + Requires a prior ``checkpoint_sample`` call to push weights to the sampler. + Pass ``sampling_session_id`` and ``seq_id`` from that call in the request body. + + Poll the returned operation URL for ``FineTuningSampleResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. 
+ :type body: ~azure.ai.finetuning_sessions.models.SampleRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_sample( + self, + session_id: str, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Generate completions. + + Generate one or more completions using the session's current LoRA weights. + + Requires a prior ``checkpoint_sample`` call to push weights to the sampler. + Pass ``sampling_session_id`` and ``seq_id`` from that call in the request body. + + Poll the returned operation URL for ``FineTuningSampleResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. 
+ :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def begin_sample( + self, + session_id: str, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Generate completions. + + Generate one or more completions using the session's current LoRA weights. + + Requires a prior ``checkpoint_sample`` call to push weights to the sampler. + Pass ``sampling_session_id`` and ``seq_id`` from that call in the request body. + + Poll the returned operation URL for ``FineTuningSampleResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def begin_sample( + self, + session_id: str, + body: Union[_models.SampleRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> AsyncLROPoller[_models.OperationResult]: + """Generate completions. + + Generate one or more completions using the session's current LoRA weights. + + Requires a prior ``checkpoint_sample`` call to push weights to the sampler. + Pass ``sampling_session_id`` and ``seq_id`` from that call in the request body. + + Poll the returned operation URL for ``FineTuningSampleResult``. + + :param session_id: Required. + :type session_id: str + :param body: Is one of the following types: SampleRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.SampleRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of AsyncLROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: + ~azure.core.polling.AsyncLROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, AsyncPollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = await self._sample_initial( + session_id=session_id, + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + await raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: AsyncPollingMethod = cast( + AsyncPollingMethod, + AsyncLROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs), + ) + elif polling is False: + polling_method = cast(AsyncPollingMethod, AsyncNoPolling()) + else: + polling_method = polling + if cont_token: + return 
AsyncLROPoller[_models.OperationResult].from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return AsyncLROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + +class Operations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~azure.ai.finetuning_sessions.aio.FineTuningSessionClient`'s + :attr:`operations` attribute. + """ + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: FineTuningSessionClientConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace_async + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + async def get( + self, + session_id: str, + operation_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> _models.OperationResult: + """Poll operation status. + + Poll for the result of an async fine-tuning operation (Azure LRO). + + Returns ``status: "running"`` while in progress, ``"succeeded"`` on completion, + or ``"failed"`` on error. When succeeded, the body contains the typed result + (``ForwardBackwardOperationResult``, ``OptimStepOperationResult``, etc.) + discriminated by ``type``. + + :param session_id: Required. + :type session_id: str + :param operation_id: Required. 
+ :type operation_id: str + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: OperationResult. The OperationResult is compatible with MutableMapping + :rtype: ~azure.ai.finetuning_sessions.models.OperationResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + + _request = build_operations_get_request( + session_id=session_id, + operation_id=operation_id, + foundry_features=foundry_features, + api_version=api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + await response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = 
_failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + if _stream: + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + else: + deserialized = _deserialize(_models.OperationResult, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/aio/operations/_patch.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/aio/operations/_patch.py new file mode 100644 index 000000000000..87676c65a8f0 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/aio/operations/_patch.py @@ -0,0 +1,21 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------- +"""Customize generated code here. + +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" + + +__all__: list[str] = [] # Add all objects you want publicly available to users at this package level + + +def patch_sdk(): + """Do not remove from this file. 
+ + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/__init__.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/__init__.py new file mode 100644 index 000000000000..4dba4f1f078a --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/__init__.py @@ -0,0 +1,114 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + + +from ._models import ( # type: ignore + AdamParams, + ApiError, + ApiErrorResponse, + Checkpoint, + CheckpointInfo, + CheckpointList, + CreateSessionRequest, + Cursor, + Datum, + ForwardBackwardInput, + ForwardBackwardOperationResult, + ForwardBackwardRequest, + ForwardInput, + ForwardRequest, + HeartbeatResponse, + LoRAConfig, + LossFnConfig, + LossFnInputs, + ModelInput, + ModelInputChunk, + OperationResult, + OptimStepOperationResult, + OptimStepRequest, + SampleOperationResult, + SampleRequest, + SampledSequence, + SamplingParams, + SaveCheckpointOperationResult, + SaveCheckpointRequest, + SaveSamplerWeightsOperationResult, + SaveSamplerWeightsRequest, + Session, + SessionList, + SessionModelData, + SessionSummary, + TensorData, +) + +from ._enums import ( # type: ignore + 
CheckpointType, + FoundryFeaturesOptInKeys, + LossFn, + OperationStatus, + OperationType, + SessionStatus, + SessionType, +) +from ._patch import __all__ as _patch_all +from ._patch import * +from ._patch import patch_sdk as _patch_sdk + +__all__ = [ + "AdamParams", + "ApiError", + "ApiErrorResponse", + "Checkpoint", + "CheckpointInfo", + "CheckpointList", + "CreateSessionRequest", + "Cursor", + "Datum", + "ForwardBackwardInput", + "ForwardBackwardOperationResult", + "ForwardBackwardRequest", + "ForwardInput", + "ForwardRequest", + "HeartbeatResponse", + "LoRAConfig", + "LossFnConfig", + "LossFnInputs", + "ModelInput", + "ModelInputChunk", + "OperationResult", + "OptimStepOperationResult", + "OptimStepRequest", + "SampleOperationResult", + "SampleRequest", + "SampledSequence", + "SamplingParams", + "SaveCheckpointOperationResult", + "SaveCheckpointRequest", + "SaveSamplerWeightsOperationResult", + "SaveSamplerWeightsRequest", + "Session", + "SessionList", + "SessionModelData", + "SessionSummary", + "TensorData", + "CheckpointType", + "FoundryFeaturesOptInKeys", + "LossFn", + "OperationStatus", + "OperationType", + "SessionStatus", + "SessionType", +] +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore +_patch_sdk() diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/_enums.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/_enums.py new file mode 100644 index 000000000000..ecc09f73674f --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/_enums.py @@ -0,0 +1,102 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. 
+# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- + +from enum import Enum +from azure.core import CaseInsensitiveEnumMeta + + +class CheckpointType(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Fine-tuning checkpoint type.""" + + TRAINING = "training" + """Full training checkpoint (optimizer state + LoRA weights).""" + SAMPLER = "sampler" + """Sampler-compatible weights snapshot (no optimizer state).""" + + +class FoundryFeaturesOptInKeys(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Type of FoundryFeaturesOptInKeys.""" + + EVALUATIONS_V1_PREVIEW = "Evaluations=V1Preview" + """EVALUATIONS_V1_PREVIEW.""" + SCHEDULES_V1_PREVIEW = "Schedules=V1Preview" + """SCHEDULES_V1_PREVIEW.""" + RED_TEAMS_V1_PREVIEW = "RedTeams=V1Preview" + """RED_TEAMS_V1_PREVIEW.""" + INSIGHTS_V1_PREVIEW = "Insights=V1Preview" + """INSIGHTS_V1_PREVIEW.""" + MEMORY_STORES_V1_PREVIEW = "MemoryStores=V1Preview" + """MEMORY_STORES_V1_PREVIEW.""" + FINETUNING_SESSIONS_V1_PREVIEW = "FineTuningSessions=V1Preview" + """FINETUNING_SESSIONS_V1_PREVIEW.""" + + +class LossFn(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Supported loss functions for forward-backward passes.""" + + CROSS_ENTROPY = "cross_entropy" + """Standard language-model cross-entropy loss.""" + IMPORTANCE_SAMPLING = "importance_sampling" + """REINFORCE-style importance-sampled policy gradient loss.""" + PPO = "ppo" + """PPO (Proximal Policy Optimization) clipped surrogate loss.""" + CISPO = "cispo" + """Clipped Importance-Sampled Policy Optimization (CISPO) loss.""" + SAPO = "sapo" + """Soft-Advantage Policy Optimization (SAPO) loss.""" + + +class OperationStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Lifecycle status of an async fine-tuning operation — standard Azure LRO values.""" + + RUNNING = "running" + """Operation is in progress.""" + SUCCEEDED = "succeeded" + """Operation completed 
successfully.""" + FAILED = "failed" + """Operation failed.""" + + +class OperationType(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Discriminator values for async fine-tuning operation results.""" + + FORWARD_BACKWARD = "forward_backward" + """A forward-backward pass operation.""" + OPTIM_STEP = "optim_step" + """An optimizer step operation.""" + SAMPLE = "sample" + """A sampling operation.""" + SAVE_CHECKPOINT = "save_checkpoint" + """A training checkpoint save operation.""" + SAVE_SAMPLER_WEIGHTS = "save_sampler_weights" + """A sampler-weights save operation.""" + + +class SessionStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Lifecycle status of a fine-tuning session. + + Uses the canonical fine-tuning job vocabulary for consistency + across all API surfaces. + """ + + QUEUED = "queued" + """Session has been created and is waiting for the GPU engine.""" + RUNNING = "running" + """Session is loaded and actively processing requests.""" + SUCCEEDED = "succeeded" + """Session completed successfully.""" + FAILED = "failed" + """Session ended in a non-success terminal state (engine death, expiry, + or unrecoverable failure). Any in-memory weights are lost.""" + + +class SessionType(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Fine-tuning session type.""" + + TRAINING = "training" + """A training session for fine-tuning a model.""" diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/_models.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/_models.py new file mode 100644 index 000000000000..f5c9f7bf8422 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/_models.py @@ -0,0 +1,1406 @@ +# pylint: disable=line-too-long,useless-suppression,too-many-lines +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +# pylint: disable=useless-super-delegation + +import datetime +from typing import Any, Literal, Mapping, Optional, TYPE_CHECKING, Union, overload + +from .._utils.model_base import Model as _Model, rest_discriminator, rest_field +from ._enums import OperationType + +if TYPE_CHECKING: + from .. import _types, models as _models + + +class AdamParams(_Model): + """Adam optimizer hyper-parameters. + + :ivar learning_rate: Learning rate. Required. + :vartype learning_rate: float + :ivar beta1: Adam β₁ coefficient. Required. + :vartype beta1: float + :ivar beta2: Adam β₂ coefficient. Required. + :vartype beta2: float + :ivar eps: Adam ε (numerical stability floor). Required. + :vartype eps: float + :ivar weight_decay: L₂ weight-decay coefficient. Required. + :vartype weight_decay: float + """ + + learning_rate: float = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Learning rate. Required.""" + beta1: float = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Adam β₁ coefficient. Required.""" + beta2: float = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Adam β₂ coefficient. Required.""" + eps: float = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Adam ε (numerical stability floor). Required.""" + weight_decay: float = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """L₂ weight-decay coefficient. Required.""" + + @overload + def __init__( + self, + *, + learning_rate: float, + beta1: float, + beta2: float, + eps: float, + weight_decay: float, + ) -> None: ... 
+ + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ApiError(_Model): + """ApiError. + + :ivar code: Required. + :vartype code: str + :ivar message: Required. + :vartype message: str + :ivar param: + :vartype param: str + :ivar type: + :vartype type: str + :ivar details: + :vartype details: list[~azure.ai.finetuning_sessions.models.ApiError] + :ivar additional_info: + :vartype additional_info: dict[str, any] + :ivar debug_info: + :vartype debug_info: dict[str, any] + """ + + code: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + message: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + param: Optional[str] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + type: Optional[str] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + details: Optional[list["_models.ApiError"]] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + additional_info: Optional[dict[str, Any]] = rest_field( + name="additionalInfo", visibility=["read", "create", "update", "delete", "query"] + ) + debug_info: Optional[dict[str, Any]] = rest_field( + name="debugInfo", visibility=["read", "create", "update", "delete", "query"] + ) + + @overload + def __init__( + self, + *, + code: str, + message: str, + param: Optional[str] = None, + type: Optional[str] = None, + details: Optional[list["_models.ApiError"]] = None, + additional_info: Optional[dict[str, Any]] = None, + debug_info: Optional[dict[str, Any]] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. 
+ :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ApiErrorResponse(_Model): + """Error response for API failures. + + :ivar error: Required. + :vartype error: ~azure.ai.finetuning_sessions.models.ApiError + """ + + error: "_models.ApiError" = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + + @overload + def __init__( + self, + *, + error: "_models.ApiError", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class Checkpoint(_Model): + """A checkpoint item returned in GET /fine_tuning/sessions/{sessionId}/checkpoints. + + :ivar checkpoint_id: Required. + :vartype checkpoint_id: str + :ivar checkpoint_type: Required. Known values are: "training" and "sampler". + :vartype checkpoint_type: str or ~azure.ai.finetuning_sessions.models.CheckpointType + :ivar time: Timestamp when the checkpoint was saved. Required. + :vartype time: ~datetime.datetime + """ + + checkpoint_id: str = rest_field(visibility=["read"]) + """Required.""" + checkpoint_type: Union[str, "_models.CheckpointType"] = rest_field(visibility=["read"]) + """Required. Known values are: \"training\" and \"sampler\".""" + time: datetime.datetime = rest_field(visibility=["read"], format="rfc3339") + """Timestamp when the checkpoint was saved. Required.""" + + +class CheckpointInfo(_Model): + """Detailed metadata for a single checkpoint. + + :ivar base_model: Required. + :vartype base_model: str + :ivar is_lora: Required. 
+ :vartype is_lora: bool + :ivar lora_rank: + :vartype lora_rank: int + """ + + base_model: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + is_lora: bool = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + lora_rank: Optional[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + + @overload + def __init__( + self, + *, + base_model: str, + is_lora: bool, + lora_rank: Optional[int] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class CheckpointList(_Model): + """List of all checkpoints for a session (no pagination). + + :ivar checkpoints: Required. + :vartype checkpoints: list[~azure.ai.finetuning_sessions.models.Checkpoint] + """ + + checkpoints: list["_models.Checkpoint"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + + @overload + def __init__( + self, + *, + checkpoints: list["_models.Checkpoint"], + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class CreateSessionRequest(_Model): + """Request body for POST /fine_tuning/sessions. + + :ivar type: The session type. Required. "training" + :vartype type: str or ~azure.ai.finetuning_sessions.models.SessionType + :ivar base_model: Required. + :vartype base_model: str + :ivar lora_config: LoRA adapter config. Rank is fixed server-side for v1; omit to use the + server default.
+ :vartype lora_config: ~azure.ai.finetuning_sessions.models.LoRAConfig + :ivar user_metadata: + :vartype user_metadata: dict[str, str] + :ivar ejectable: Opt the session into idle hibernation. Default False. + :vartype ejectable: bool + """ + + type: Union[str, "_models.SessionType"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """The session type. Required. \"training\"""" + base_model: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + lora_config: Optional["_models.LoRAConfig"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """LoRA adapter config. Rank is fixed server-side for v1; omit to use the server default.""" + user_metadata: Optional[dict[str, str]] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + ejectable: Optional[bool] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Opt the session into idle hibernation. Default False.""" + + @overload + def __init__( + self, + *, + type: Union[str, "_models.SessionType"], + base_model: str, + lora_config: Optional["_models.LoRAConfig"] = None, + user_metadata: Optional[dict[str, str]] = None, + ejectable: Optional[bool] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class Cursor(_Model): + """Pagination cursor returned in list responses. + + :ivar offset: Zero-based index of the first item in the current page. Required. + :vartype offset: int + :ivar limit: Maximum number of items per page. Required. + :vartype limit: int + :ivar total_count: Total number of items across all pages. Required. + :vartype total_count: int + """ + + offset: int = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Zero-based index of the first item in the current page. 
Required.""" + limit: int = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Maximum number of items per page. Required.""" + total_count: int = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Total number of items across all pages. Required.""" + + @overload + def __init__( + self, + *, + offset: int, + limit: int, + total_count: int, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class Datum(_Model): + """A single training example. + + :ivar model_input: Token-ID input to the model. Required. + :vartype model_input: ~azure.ai.finetuning_sessions.models.ModelInput + :ivar loss_fn_inputs: Loss-function targets, aligned with model_input tokens. Required. + :vartype loss_fn_inputs: ~azure.ai.finetuning_sessions.models.LossFnInputs + """ + + model_input: "_models.ModelInput" = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Token-ID input to the model. Required.""" + loss_fn_inputs: "_models.LossFnInputs" = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Loss-function targets, aligned with model_input tokens. Required.""" + + @overload + def __init__( + self, + *, + model_input: "_models.ModelInput", + loss_fn_inputs: "_models.LossFnInputs", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ForwardBackwardInput(_Model): + """Inner payload for a forward-backward request. + + :ivar data: Required. 
+ :vartype data: list[~azure.ai.finetuning_sessions.models.Datum] + :ivar loss_fn: Required. Known values are: "cross_entropy", "importance_sampling", "ppo", + "cispo", and "sapo". + :vartype loss_fn: str or ~azure.ai.finetuning_sessions.models.LossFn + :ivar loss_fn_config: + :vartype loss_fn_config: ~azure.ai.finetuning_sessions.models.LossFnConfig + """ + + data: list["_models.Datum"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + loss_fn: Union[str, "_models.LossFn"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required. Known values are: \"cross_entropy\", \"importance_sampling\", \"ppo\", \"cispo\", and + \"sapo\".""" + loss_fn_config: Optional["_models.LossFnConfig"] = rest_field( + visibility=["read", "create", "update", "delete", "query"] + ) + + @overload + def __init__( + self, + *, + data: list["_models.Datum"], + loss_fn: Union[str, "_models.LossFn"], + loss_fn_config: Optional["_models.LossFnConfig"] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class OperationResult(_Model): + """Discriminated union of all async operation results. Returned by GET + /fine_tuning/sessions/{sessionId}/operations/{operationId}. + + You probably want to use the sub-classes and not this class directly. Known sub-classes are: + ForwardBackwardOperationResult, OptimStepOperationResult, SampleOperationResult, + SaveCheckpointOperationResult, SaveSamplerWeightsOperationResult + + :ivar type: Required. Known values are: "forward_backward", "optim_step", "sample", + "save_checkpoint", and "save_sampler_weights". + :vartype type: str or ~azure.ai.finetuning_sessions.models.OperationType + :ivar operation_id: Required. 
+ :vartype operation_id: str + :ivar status: Required. Known values are: "running", "succeeded", and "failed". + :vartype status: str or ~azure.ai.finetuning_sessions.models.OperationStatus + """ + + __mapping__: dict[str, _Model] = {} + type: str = rest_discriminator(name="type", visibility=["read", "create", "update", "delete", "query"]) + """Required. Known values are: \"forward_backward\", \"optim_step\", \"sample\", + \"save_checkpoint\", and \"save_sampler_weights\".""" + operation_id: str = rest_field(visibility=["read"]) + """Required.""" + status: Union[str, "_models.OperationStatus"] = rest_field(visibility=["read"]) + """Required. Known values are: \"running\", \"succeeded\", and \"failed\".""" + + @overload + def __init__( + self, + *, + type: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ForwardBackwardOperationResult(OperationResult, discriminator="forward_backward"): + """ForwardBackwardOperationResult. + + :ivar operation_id: Required. + :vartype operation_id: str + :ivar status: Required. Known values are: "running", "succeeded", and "failed". + :vartype status: str or ~azure.ai.finetuning_sessions.models.OperationStatus + :ivar type: Required. A forward-backward pass operation. + :vartype type: str or ~azure.ai.finetuning_sessions.models.FORWARD_BACKWARD + :ivar total_loss: Required. + :vartype total_loss: float + :ivar per_datum_logprobs: + :vartype per_datum_logprobs: list[~azure.ai.finetuning_sessions.models.TensorData] + :ivar metrics: + :vartype metrics: dict[str, float] + """ + + type: Literal[OperationType.FORWARD_BACKWARD] = rest_discriminator(name="type", visibility=["read", "create", "update", "delete", "query"]) # type: ignore + """Required. 
A forward-backward pass operation.""" + total_loss: float = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + per_datum_logprobs: Optional[list["_models.TensorData"]] = rest_field( + visibility=["read", "create", "update", "delete", "query"] + ) + metrics: Optional[dict[str, float]] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + + @overload + def __init__( + self, + *, + total_loss: float, + per_datum_logprobs: Optional[list["_models.TensorData"]] = None, + metrics: Optional[dict[str, float]] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.type = OperationType.FORWARD_BACKWARD # type: ignore + + +class ForwardBackwardRequest(_Model): + """Request body for POST /fine_tuning/sessions/{sessionId}/forward_backward. + + :ivar forward_backward_input: Required. + :vartype forward_backward_input: ~azure.ai.finetuning_sessions.models.ForwardBackwardInput + """ + + forward_backward_input: "_models.ForwardBackwardInput" = rest_field( + visibility=["read", "create", "update", "delete", "query"] + ) + """Required.""" + + @overload + def __init__( + self, + *, + forward_backward_input: "_models.ForwardBackwardInput", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ForwardInput(ForwardBackwardInput): + """Forward-only payload. + + Currently identical to ForwardBackwardInput. Exists as a named alias so + that the forward endpoint has its own type in the SDK, allowing the + contract to diverge later (e.g. 
making loss_fn optional for forward). + """ + + +class ForwardRequest(_Model): + """Request body for POST /fine_tuning/sessions/{sessionId}/forward. + + :ivar forward_input: Required. + :vartype forward_input: ~azure.ai.finetuning_sessions.models.ForwardInput + """ + + forward_input: "_models.ForwardInput" = rest_field( + visibility=["read", "create", "update", "delete", "query"] + ) + """Required.""" + + @overload + def __init__( + self, + *, + forward_input: "_models.ForwardInput", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class HeartbeatResponse(_Model): + """Response from POST /fine_tuning/sessions/{sessionId}/heartbeat. + + :ivar session_id: Required. + :vartype session_id: str + """ + + session_id: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + + @overload + def __init__( + self, + *, + session_id: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class LoRAConfig(_Model): + """LoRA adapter configuration. ``rank`` is fixed server-side for v1; omit to use the server + default. + + :ivar rank: Number of LoRA rank dimensions. + :vartype rank: int + :ivar alpha: LoRA scaling factor (effective scale = alpha / rank). Defaults to 32.0 server-side. + :vartype alpha: float + :ivar seed: Seed for LoRA weight initialisation. If omitted, the server picks a random seed and + echoes it back. 
+ :vartype seed: int + """ + + rank: Optional[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Number of LoRA rank dimensions.""" + alpha: Optional[float] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """LoRA scaling factor (effective scale = alpha / rank). Defaults to 32.0 server-side.""" + seed: Optional[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Seed for LoRA weight initialisation. If omitted, the server picks a random seed and echoes it + back.""" + + @overload + def __init__( + self, + *, + rank: Optional[int] = None, + alpha: Optional[float] = None, + seed: Optional[int] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class LossFnConfig(_Model): + """Optional per-loss-function hyper-parameters. + + :ivar clip_low_threshold: PPO (Proximal Policy Optimization) / CISPO (Clipped + Importance-Sampled Policy Optimization): lower clip threshold. + :vartype clip_low_threshold: float + :ivar clip_high_threshold: PPO (Proximal Policy Optimization) / CISPO (Clipped + Importance-Sampled Policy Optimization): upper clip threshold. + :vartype clip_high_threshold: float + :ivar tau_pos: SAPO (Soft-Advantage Policy Optimization): positive advantage temperature. + :vartype tau_pos: float + :ivar tau_neg: SAPO (Soft-Advantage Policy Optimization): negative advantage temperature. 
+ :vartype tau_neg: float + """ + + clip_low_threshold: Optional[float] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """PPO (Proximal Policy Optimization) / CISPO (Clipped Importance-Sampled Policy Optimization): + lower clip threshold.""" + clip_high_threshold: Optional[float] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """PPO (Proximal Policy Optimization) / CISPO (Clipped Importance-Sampled Policy Optimization): + upper clip threshold.""" + tau_pos: Optional[float] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """SAPO (Soft-Advantage Policy Optimization): positive advantage temperature.""" + tau_neg: Optional[float] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """SAPO (Soft-Advantage Policy Optimization): negative advantage temperature.""" + + @overload + def __init__( + self, + *, + clip_low_threshold: Optional[float] = None, + clip_high_threshold: Optional[float] = None, + tau_pos: Optional[float] = None, + tau_neg: Optional[float] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class LossFnInputs(_Model): + """Per-datum loss function inputs used in forward-backward. + + :ivar target_tokens: Target token ids (shifted by 1 relative to model input). Required. + :vartype target_tokens: ~azure.ai.finetuning_sessions.models.TensorData + :ivar weights: Per-token weights (0.0 = masked, 1.0 = counted). Required. + :vartype weights: ~azure.ai.finetuning_sessions.models.TensorData + :ivar advantages: Per-token advantage estimates (required for REINFORCE/PPO (Proximal Policy + Optimization)/CISPO (Clipped Importance-Sampled Policy Optimization)/SAPO (Soft-Advantage + Policy Optimization)). 
Omit or pass empty array for cross-entropy. + :vartype advantages: ~azure.ai.finetuning_sessions.models.TensorData + :ivar logprobs: Per-token reference log-probabilities for KL-divergence regularisation. Omit or + pass empty array to skip KL. + :vartype logprobs: ~azure.ai.finetuning_sessions.models.TensorData + """ + + target_tokens: "_models.TensorData" = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Target token ids (shifted by 1 relative to model input). Required.""" + weights: "_models.TensorData" = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Per-token weights (0.0 = masked, 1.0 = counted). Required.""" + advantages: Optional["_models.TensorData"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Per-token advantage estimates (required for REINFORCE/PPO (Proximal Policy Optimization)/CISPO + (Clipped Importance-Sampled Policy Optimization)/SAPO (Soft-Advantage Policy Optimization)). + Omit or pass empty array for cross-entropy.""" + logprobs: Optional["_models.TensorData"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Per-token reference log-probabilities for KL-divergence regularisation. Omit or pass empty + array to skip KL.""" + + @overload + def __init__( + self, + *, + target_tokens: "_models.TensorData", + weights: "_models.TensorData", + advantages: Optional["_models.TensorData"] = None, + logprobs: Optional["_models.TensorData"] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ModelInput(_Model): + """Full model input as one or more token-ID chunks. + + :ivar chunks: Ordered list of token-ID chunks that together form the complete model input. + Required. 
+ :vartype chunks: list[~azure.ai.finetuning_sessions.models.ModelInputChunk] + """ + + chunks: list["_models.ModelInputChunk"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Ordered list of token-ID chunks that together form the complete model input. Required.""" + + @overload + def __init__( + self, + *, + chunks: list["_models.ModelInputChunk"], + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ModelInputChunk(_Model): + """A contiguous block of token IDs forming part of a model input. + + :ivar tokens: Sequence of token IDs in this chunk. Required. + :vartype tokens: list[int] + """ + + tokens: list[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Sequence of token IDs in this chunk. Required.""" + + @overload + def __init__( + self, + *, + tokens: list[int], + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class OptimStepOperationResult(OperationResult, discriminator="optim_step"): + """OptimStepOperationResult. + + :ivar operation_id: Required. + :vartype operation_id: str + :ivar status: Required. Known values are: "running", "succeeded", and "failed". + :vartype status: str or ~azure.ai.finetuning_sessions.models.OperationStatus + :ivar type: Required. An optimizer step operation. + :vartype type: str or ~azure.ai.finetuning_sessions.models.OPTIM_STEP + :ivar grad_norm: Required. + :vartype grad_norm: float + :ivar step_count: Required. 
+ :vartype step_count: int + :ivar metrics: + :vartype metrics: dict[str, float] + """ + + type: Literal[OperationType.OPTIM_STEP] = rest_discriminator(name="type", visibility=["read", "create", "update", "delete", "query"]) # type: ignore + """Required. An optimizer step operation.""" + grad_norm: float = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + step_count: int = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + metrics: Optional[dict[str, float]] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + + @overload + def __init__( + self, + *, + grad_norm: float, + step_count: int, + metrics: Optional[dict[str, float]] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.type = OperationType.OPTIM_STEP # type: ignore + + +class OptimStepRequest(_Model): + """Request body for POST /fine_tuning/sessions/{sessionId}/optim_step. + + :ivar adam_params: Required. + :vartype adam_params: ~azure.ai.finetuning_sessions.models.AdamParams + """ + + adam_params: "_models.AdamParams" = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + + @overload + def __init__( + self, + *, + adam_params: "_models.AdamParams", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class SampledSequence(_Model): + """A single sampled sequence. + + :ivar tokens: Required. + :vartype tokens: list[int] + :ivar text: Decoded text of the generated sequence. 
+ :vartype text: str + :ivar logprobs: + :vartype logprobs: list[float] + :ivar prompt_logprobs: + :vartype prompt_logprobs: list[float] + """ + + tokens: list[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + text: Optional[str] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Decoded text of the generated sequence.""" + logprobs: Optional[list[float]] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + prompt_logprobs: Optional[list[float]] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + + @overload + def __init__( + self, + *, + tokens: list[int], + text: Optional[str] = None, + logprobs: Optional[list[float]] = None, + prompt_logprobs: Optional[list[float]] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class SampleOperationResult(OperationResult, discriminator="sample"): + """SampleOperationResult. + + :ivar operation_id: Required. + :vartype operation_id: str + :ivar status: Required. Known values are: "running", "succeeded", and "failed". + :vartype status: str or ~azure.ai.finetuning_sessions.models.OperationStatus + :ivar type: Required. A sampling operation. + :vartype type: str or ~azure.ai.finetuning_sessions.models.SAMPLE + :ivar sequences: Required. + :vartype sequences: list[~azure.ai.finetuning_sessions.models.SampledSequence] + """ + + type: Literal[OperationType.SAMPLE] = rest_discriminator(name="type", visibility=["read", "create", "update", "delete", "query"]) # type: ignore + """Required. 
A sampling operation.""" + sequences: list["_models.SampledSequence"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + + @overload + def __init__( + self, + *, + sequences: list["_models.SampledSequence"], + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.type = OperationType.SAMPLE # type: ignore + + +class SampleRequest(_Model): + """Request body for POST /fine_tuning/sessions/{sessionId}/sample. + + :ivar num_samples: Number of independent completions to generate. Default 1. Required. + :vartype num_samples: int + :ivar prompt: Tokenised input prompt. Required. + :vartype prompt: ~azure.ai.finetuning_sessions.models.ModelInput + :ivar sampling_params: Required. + :vartype sampling_params: ~azure.ai.finetuning_sessions.models.SamplingParams + :ivar sampling_session_id: Sampling session ID from a prior save_sampler_weights call. + :vartype sampling_session_id: str + :ivar seq_id: Training step index; must match the seq_id used in save_sampler_weights. + :vartype seq_id: int + :ivar prompt_logprobs: If true, return per-token log-probabilities for the prompt tokens. + :vartype prompt_logprobs: bool + :ivar topk_prompt_logprobs: Number of top-k log-probabilities to return per prompt token. 0 = + none. Required. + :vartype topk_prompt_logprobs: int + """ + + num_samples: int = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Number of independent completions to generate. Default 1. Required.""" + prompt: "_models.ModelInput" = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Tokenised input prompt. 
Required.""" + sampling_params: "_models.SamplingParams" = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + sampling_session_id: Optional[str] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Sampling session ID from a prior save_sampler_weights call.""" + seq_id: Optional[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Training step index; must match the seq_id used in save_sampler_weights.""" + prompt_logprobs: Optional[bool] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """If true, return per-token log-probabilities for the prompt tokens.""" + topk_prompt_logprobs: int = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Number of top-k log-probabilities to return per prompt token. 0 = none. Required.""" + + @overload + def __init__( + self, + *, + num_samples: int, + prompt: "_models.ModelInput", + sampling_params: "_models.SamplingParams", + topk_prompt_logprobs: int, + sampling_session_id: Optional[str] = None, + seq_id: Optional[int] = None, + prompt_logprobs: Optional[bool] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class SamplingParams(_Model): + """Token-generation sampling parameters. + + :ivar max_tokens: Maximum tokens to generate. Required. + :vartype max_tokens: int + :ivar temperature: Softmax temperature. Default 1.0. Required. + :vartype temperature: float + :ivar top_p: Nucleus (top-p) probability mass. Default 1.0. Required. + :vartype top_p: float + :ivar top_k: Top-k candidates. -1 = disabled. Required. + :vartype top_k: int + :ivar seed: RNG seed for reproducible samples. Server chooses randomly if omitted. 
+ :vartype seed: int + :ivar stop_criteria: Stop criteria: either stop_token_ids or stop_strings, not both. Is either + a [int] type or a [str] type. + :vartype stop_criteria: list[int] or list[str] + """ + + max_tokens: int = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Maximum tokens to generate. Required.""" + temperature: float = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Softmax temperature. Default 1.0. Required.""" + top_p: float = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Nucleus (top-p) probability mass. Default 1.0. Required.""" + top_k: int = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Top-k candidates. -1 = disabled. Required.""" + seed: Optional[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """RNG seed for reproducible samples. Server chooses randomly if omitted.""" + stop_criteria: Optional["_types.StopCriteria"] = rest_field( + visibility=["read", "create", "update", "delete", "query"] + ) + """Stop criteria: either stop_token_ids or stop_strings, not both. Is either a [int] type or a + [str] type.""" + + @overload + def __init__( + self, + *, + max_tokens: int, + temperature: float, + top_p: float, + top_k: int, + seed: Optional[int] = None, + stop_criteria: Optional["_types.StopCriteria"] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class SaveCheckpointOperationResult(OperationResult, discriminator="save_checkpoint"): + """SaveCheckpointOperationResult. + + :ivar operation_id: Required. + :vartype operation_id: str + :ivar status: Required. Known values are: "running", "succeeded", and "failed". 
+ :vartype status: str or ~azure.ai.finetuning_sessions.models.OperationStatus + :ivar type: Required. A training checkpoint save operation. + :vartype type: str or ~azure.ai.finetuning_sessions.models.SAVE_CHECKPOINT + :ivar checkpoint_id: Required. + :vartype checkpoint_id: str + :ivar path: Required. + :vartype path: str + """ + + type: Literal[OperationType.SAVE_CHECKPOINT] = rest_discriminator(name="type", visibility=["read", "create", "update", "delete", "query"]) # type: ignore + """Required. A training checkpoint save operation.""" + checkpoint_id: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + path: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + + @overload + def __init__( + self, + *, + checkpoint_id: str, + path: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.type = OperationType.SAVE_CHECKPOINT # type: ignore + + +class SaveCheckpointRequest(_Model): + """Request body for POST /fine_tuning/sessions/{sessionId}/checkpoint. + + :ivar path: User-supplied checkpoint identifier. Alphanumeric plus underscores and hyphens; max + 255 characters. Required. + :vartype path: str + :ivar step_number: Training iteration this checkpoint corresponds to. Optional. + :vartype step_number: int + :ivar metrics: Per-step evaluation metrics at checkpoint time. Optional. + :vartype metrics: dict[str, any] + """ + + path: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """User-supplied checkpoint identifier. Alphanumeric plus underscores and hyphens; max 255 + characters. 
Required.""" + + step_number: Optional[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Training iteration this checkpoint corresponds to. Optional.""" + + metrics: Optional[dict[str, Any]] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Per-step evaluation metrics at checkpoint time. Optional.""" + + @overload + def __init__( + self, + *, + path: str, + step_number: Optional[int] = None, + metrics: Optional[dict[str, Any]] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class SaveSamplerWeightsOperationResult(OperationResult, discriminator="save_sampler_weights"): + """SaveSamplerWeightsOperationResult. + + :ivar operation_id: Required. + :vartype operation_id: str + :ivar status: Required. Known values are: "running", "succeeded", and "failed". + :vartype status: str or ~azure.ai.finetuning_sessions.models.OperationStatus + :ivar type: Required. A sampler-weights save operation. + :vartype type: str or ~azure.ai.finetuning_sessions.models.SAVE_SAMPLER_WEIGHTS + :ivar checkpoint_id: Required. + :vartype checkpoint_id: str + :ivar sampling_session_id: Required. + :vartype sampling_session_id: str + """ + + type: Literal[OperationType.SAVE_SAMPLER_WEIGHTS] = rest_discriminator(name="type", visibility=["read", "create", "update", "delete", "query"]) # type: ignore + """Required. A sampler-weights save operation.""" + checkpoint_id: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + sampling_session_id: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + + @overload + def __init__( + self, + *, + checkpoint_id: str, + sampling_session_id: str, + ) -> None: ... 
+ + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.type = OperationType.SAVE_SAMPLER_WEIGHTS # type: ignore + + +class SaveSamplerWeightsRequest(_Model): + """Request body for POST /fine_tuning/sessions/{sessionId}/checkpoint_sample. + + :ivar path: Optional explicit identifier for the sampler checkpoint. + :vartype path: str + :ivar sampling_session_seq_id: Ordinal of this sampling session within the training run. + :vartype sampling_session_seq_id: int + :ivar seq_id: Training step sequence number. + :vartype seq_id: int + """ + + path: Optional[str] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Optional explicit identifier for the sampler checkpoint.""" + sampling_session_seq_id: Optional[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Ordinal of this sampling session within the training run.""" + seq_id: Optional[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Training step sequence number.""" + + @overload + def __init__( + self, + *, + path: Optional[str] = None, + sampling_session_seq_id: Optional[int] = None, + seq_id: Optional[int] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class Session(_Model): + """Response from GET /fine_tuning/sessions/{sessionId}. + + :ivar session_id: Unique identifier for this fine-tuning session. Required. + :vartype session_id: str + :ivar type: The session type. Required. 
"training" + :vartype type: str or ~azure.ai.finetuning_sessions.models.SessionType + :ivar status: Current lifecycle status of the session. Required. Known values are: "queued", + "running", "succeeded", and "failed". + :vartype status: str or ~azure.ai.finetuning_sessions.models.SessionStatus + :ivar model_data: Model and adapter configuration associated with this session. Required. + :vartype model_data: ~azure.ai.finetuning_sessions.models.SessionModelData + """ + + session_id: str = rest_field(visibility=["read"]) + """Unique identifier for this fine-tuning session. Required.""" + type: Union[str, "_models.SessionType"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """The session type. Required. \"training\"""" + status: Union[str, "_models.SessionStatus"] = rest_field(visibility=["read"]) + """Current lifecycle status of the session. Required. Known values are: \"queued\", \"running\", + \"succeeded\", and \"failed\".""" + model_data: "_models.SessionModelData" = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Model and adapter configuration associated with this session. Required.""" + + @overload + def __init__( + self, + *, + type: Union[str, "_models.SessionType"], + model_data: "_models.SessionModelData", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class SessionList(_Model): + """Paginated list of fine-tuning sessions. + + :ivar data: Required. + :vartype data: list[~azure.ai.finetuning_sessions.models.SessionSummary] + :ivar cursor: Required. 
+ :vartype cursor: ~azure.ai.finetuning_sessions.models.Cursor + """ + + data: list["_models.SessionSummary"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + cursor: "_models.Cursor" = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + + @overload + def __init__( + self, + *, + data: list["_models.SessionSummary"], + cursor: "_models.Cursor", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class SessionModelData(_Model): + """Nested model sub-object within a session response. + + :ivar base_model: Required. + :vartype base_model: str + :ivar lora_config: + :vartype lora_config: ~azure.ai.finetuning_sessions.models.LoRAConfig + :ivar model_name: + :vartype model_name: str + """ + + base_model: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Required.""" + lora_config: Optional["_models.LoRAConfig"] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + model_name: Optional[str] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + + @overload + def __init__( + self, + *, + base_model: str, + lora_config: Optional["_models.LoRAConfig"] = None, + model_name: Optional[str] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class SessionSummary(_Model): + """Summary item returned in the paginated GET /fine_tuning/sessions list. + + :ivar session_id: Unique identifier for this fine-tuning session. Required. 
+ :vartype session_id: str + :ivar base_model: Base model used for this fine-tuning session. Required. + :vartype base_model: str + :ivar status: Current lifecycle status of the session. Required. Known values are: "queued", + "running", "succeeded", and "failed". + :vartype status: str or ~azure.ai.finetuning_sessions.models.SessionStatus + :ivar is_lora: Whether the session uses a LoRA adapter. Required. + :vartype is_lora: bool + :ivar lora_rank: LoRA rank, if applicable. + :vartype lora_rank: int + :ivar corrupted: Indicates whether the session state is corrupted and cannot be resumed. + Required. + :vartype corrupted: bool + :ivar last_request_time: Timestamp of the most recent request made against this session. + Required. + :vartype last_request_time: ~datetime.datetime + """ + + session_id: str = rest_field(visibility=["read"]) + """Unique identifier for this fine-tuning session. Required.""" + base_model: str = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Base model used for this fine-tuning session. Required.""" + status: Union[str, "_models.SessionStatus"] = rest_field(visibility=["read"]) + """Current lifecycle status of the session. Required. Known values are: \"queued\", \"running\", + \"succeeded\", and \"failed\".""" + is_lora: bool = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Whether the session uses a LoRA adapter. Required.""" + lora_rank: Optional[int] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """LoRA rank, if applicable.""" + corrupted: bool = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """Indicates whether the session state is corrupted and cannot be resumed. Required.""" + last_request_time: datetime.datetime = rest_field(visibility=["read"], format="rfc3339") + """Timestamp of the most recent request made against this session. 
Required.""" + + @overload + def __init__( + self, + *, + base_model: str, + is_lora: bool, + corrupted: bool, + lora_rank: Optional[int] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class TensorData(_Model): + """A 1-D array of floating-point values serialised for the wire. + + :ivar data: The floating-point values of the tensor. Required. + :vartype data: list[float] + """ + + data: list[float] = rest_field(visibility=["read", "create", "update", "delete", "query"]) + """The floating-point values of the tensor. Required.""" + + @overload + def __init__( + self, + *, + data: list[float], + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/_patch.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/_patch.py new file mode 100644 index 000000000000..24835e8ce646 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/models/_patch.py @@ -0,0 +1,66 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------- +"""Customize generated code here. 
+ +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" + +from typing import Any, Mapping, Optional, overload + +from .._utils.model_base import Model as _Model, rest_field + + +class FromCheckpoint(_Model): + """Identifies a saved training checkpoint to bootstrap a new session from. + + When passed to :meth:`~azure.ai.finetuning_sessions.FineTuningSession.create`, + the new session's LoRA weights, optimizer state, and scheduler step are all + initialised from the referenced checkpoint (continual fine-tuning). + + :ivar source_session_id: The ``model_`` of the session that saved + the checkpoint. + :vartype source_session_id: str + :ivar checkpoint_id: Name of the checkpoint within the source session. + :vartype checkpoint_id: str + """ + + source_session_id: str = rest_field() + """The ``model_`` of the session that saved the checkpoint.""" + + checkpoint_id: str = rest_field() + """Name of the checkpoint within the source session.""" + + @overload + def __init__( + self, + *, + source_session_id: str, + checkpoint_id: str, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +__all__: list[str] = [ + "FromCheckpoint", +] + + +def patch_sdk(): + """Do not remove from this file. 
+ + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/operations/__init__.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/operations/__init__.py new file mode 100644 index 000000000000..8f152cec5e7e --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/operations/__init__.py @@ -0,0 +1,33 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._operations import SessionsOperations # type: ignore +from ._operations import TrainingOperations # type: ignore +from ._operations import CheckpointsOperations # type: ignore +from ._operations import SamplingOperations # type: ignore +from ._operations import Operations # type: ignore + +from ._patch import __all__ as _patch_all +from ._patch import * +from ._patch import patch_sdk as _patch_sdk + +__all__ = [ + "SessionsOperations", + "TrainingOperations", + "CheckpointsOperations", + "SamplingOperations", + "Operations", +] +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore +_patch_sdk() diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/operations/_operations.py 
b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/operations/_operations.py new file mode 100644 index 000000000000..bf9ee2182f96 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/operations/_operations.py @@ -0,0 +1,2977 @@ +# pylint: disable=too-many-lines +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. +# -------------------------------------------------------------------------- +from collections.abc import MutableMapping +from io import IOBase +import json +from typing import Any, Callable, IO, Iterator, Literal, Optional, TypeVar, Union, cast, overload + +from azure.core import PipelineClient +from azure.core.exceptions import ( + ClientAuthenticationError, + HttpResponseError, + ResourceExistsError, + ResourceNotFoundError, + ResourceNotModifiedError, + StreamClosedError, + StreamConsumedError, + map_error, +) +from azure.core.pipeline import PipelineResponse +from azure.core.polling import LROPoller, NoPolling, PollingMethod +from azure.core.polling.base_polling import LROBasePolling +from azure.core.rest import HttpRequest, HttpResponse +from azure.core.tracing.decorator import distributed_trace +from azure.core.utils import case_insensitive_dict + +from .. 
import models as _models +from .._configuration import FineTuningSessionClientConfiguration +from .._utils.model_base import SdkJSONEncoder, _deserialize, _failsafe_deserialize +from .._utils.serialization import Deserializer, Serializer +from .._validation import api_version_validation +from ..models._enums import FoundryFeaturesOptInKeys + +JSON = MutableMapping[str, Any] +T = TypeVar("T") +ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, dict[str, Any]], Any]] +List = list + +_SERIALIZER = Serializer() +_SERIALIZER.client_side_validation = False + + +def build_sessions_create_request( + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions" + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_sessions_list_request( + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + limit: Optional[int] = None, + offset: Optional[int] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + 
accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions" + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + if limit is not None: + _params["limit"] = _SERIALIZER.query("limit", limit, "int") + if offset is not None: + _params["offset"] = _SERIALIZER.query("offset", offset, "int") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_sessions_get_request( + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_sessions_unload_request( + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = 
case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}/complete" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_sessions_heartbeat_request( + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}/heartbeat" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_training_forward_backward_request( + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = 
case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}/forward_backward" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_training_optim_step_request( + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}/optim_step" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + if 
content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_checkpoints_save_request( + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}/checkpoint" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_checkpoints_save_sampler_weights_request( # pylint: disable=name-too-long + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept 
= _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}/checkpoint_sample" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_checkpoints_list_request( + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}/checkpoints" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_checkpoints_get_request( + session_id: str, + checkpoint_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + 
api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}/checkpoints/{checkpointId}" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + "checkpointId": _SERIALIZER.url("checkpoint_id", checkpoint_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_sampling_sample_request( + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}/sample" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + if content_type is not None: + _headers["Content-Type"] = 
_SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_operations_get_request( + session_id: str, + operation_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/fine_tuning/sessions/{sessionId}/request/{requestId}" + path_format_arguments = { + "sessionId": _SERIALIZER.url("session_id", session_id, "str"), + "requestId": _SERIALIZER.url("operation_id", operation_id, "str"), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["api-version"] = _SERIALIZER.query("api_version", api_version, "str") + + # Construct headers + _headers["Foundry-Features"] = _SERIALIZER.header("foundry_features", foundry_features, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +class SessionsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~azure.ai.finetuning_sessions.FineTuningSessionClient`'s + :attr:`sessions` attribute. 
+ """ + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: FineTuningSessionClientConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def _create_initial( + self, + body: Union[_models.CreateSessionRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> Iterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_sessions_create_request( + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = 
kwargs.pop("decompress", True) + _stream = True + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def begin_create( + self, + body: _models.CreateSessionRequest, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Create a fine-tuning session. + + Create a new fine-tuning session and allocate it to a GPU engine. + + Returns ``session_id`` array ``["session_xxx", "sampling_xxx"]``. + Use ``session_xxx`` as ``{sessionId}`` for training ops and + ``sampling_xxx`` as ``{samplingId}`` for sampling ops. + + :param body: Required. + :type body: ~azure.ai.finetuning_sessions.models.CreateSessionRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. 
+ :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_create( + self, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Create a fine-tuning session. + + Create a new fine-tuning session and allocate it to a GPU engine. + + Returns ``session_id`` array ``["session_xxx", "sampling_xxx"]``. + Use ``session_xxx`` as ``{sessionId}`` for training ops and + ``sampling_xxx`` as ``{samplingId}`` for sampling ops. + + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_create( + self, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Create a fine-tuning session. + + Create a new fine-tuning session and allocate it to a GPU engine. + + Returns ``session_id`` array ``["session_xxx", "sampling_xxx"]``. + Use ``session_xxx`` as ``{sessionId}`` for training ops and + ``sampling_xxx`` as ``{samplingId}`` for sampling ops. + + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def begin_create( + self, + body: Union[_models.CreateSessionRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Create a fine-tuning session. + + Create a new fine-tuning session and allocate it to a GPU engine. + + Returns ``session_id`` array ``["session_xxx", "sampling_xxx"]``. + Use ``session_xxx`` as ``{sessionId}`` for training ops and + ``sampling_xxx`` as ``{samplingId}`` for sampling ops. + + :param body: Is one of the following types: CreateSessionRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.CreateSessionRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, PollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = self._create_initial( + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: PollingMethod = cast( + PollingMethod, LROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs) + ) + elif polling is False: + polling_method = cast(PollingMethod, NoPolling()) + else: + polling_method = polling + if cont_token: + return LROPoller[_models.OperationResult].from_continuation_token( + 
polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return LROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def create( + self, + body: Union[_models.CreateSessionRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> JSON: + """Create a fine-tuning session (synchronous 200 response). + + POSTs to ``/fine_tuning/sessions`` and returns the response body directly. + Use this instead of ``begin_create`` when the server returns HTTP 200 (not 202 LRO). + + :param body: Is one of the following types: CreateSessionRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.CreateSessionRequest or JSON or IO[bytes] + :keyword foundry_features: Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. 
+ :paramtype api_version: str + :return: dict (JSON response body) + :rtype: dict + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[JSON] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_sessions_create_request( + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize(_models.ApiErrorResponse, response) + raise HttpResponseError(response=response, model=error) + + deserialized: JSON = response.json() + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore 
+ + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def list( + self, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + limit: Optional[int] = None, + offset: Optional[int] = None, + **kwargs: Any + ) -> _models.SessionList: + """List fine-tuning sessions. + + List all fine-tuning sessions for this project. + + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword limit: Default value is None. + :paramtype limit: int + :keyword offset: Default value is None. + :paramtype offset: int + :return: SessionList. 
The SessionList is compatible with MutableMapping + :rtype: ~azure.ai.finetuning_sessions.models.SessionList + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.SessionList] = kwargs.pop("cls", None) + + _request = build_sessions_list_request( + foundry_features=foundry_features, + api_version=api_version, + limit=limit, + offset=offset, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + if _stream: + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + else: + deserialized = _deserialize(_models.SessionList, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + @api_version_validation( + 
params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def get( + self, + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> _models.Session: + """Get a fine-tuning session. + + Get information about a specific fine-tuning session. + + :param session_id: Required. + :type session_id: str + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: Session. The Session is compatible with MutableMapping + :rtype: ~azure.ai.finetuning_sessions.models.Session + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.Session] = kwargs.pop("cls", None) + + _request = build_sessions_get_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = 
pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + if _stream: + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + else: + deserialized = _deserialize(_models.Session, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def _unload_initial( + self, + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> Iterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) + + _request = build_sessions_unload_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = True + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, 
**kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def begin_unload( + self, + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Unload a fine-tuning session. + + Unload a session from the GPU engine, freeing memory. LoRA weights are lost; save a checkpoint + before calling this. + + :param session_id: Required. + :type session_id: str + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, PollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = self._unload_initial( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: PollingMethod = cast( + PollingMethod, LROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs) + ) + elif polling is False: + polling_method = cast(PollingMethod, NoPolling()) + else: + polling_method = polling + if cont_token: + return LROPoller[_models.OperationResult].from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return 
LROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def heartbeat( + self, + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> _models.HeartbeatResponse: + """Session heartbeat. + + Heartbeat — refresh an active session to prevent idle expiry. The SDK sends this automatically + every 30 seconds. Returns 404 if the session has already expired. + + :param session_id: Required. + :type session_id: str + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: HeartbeatResponse. 
The HeartbeatResponse is compatible with MutableMapping + :rtype: ~azure.ai.finetuning_sessions.models.HeartbeatResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.HeartbeatResponse] = kwargs.pop("cls", None) + + _request = build_sessions_heartbeat_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + if _stream: + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + else: + deserialized = _deserialize(_models.HeartbeatResponse, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class TrainingOperations: + """ + .. 
warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~azure.ai.finetuning_sessions.FineTuningSessionClient`'s + :attr:`training` attribute. + """ + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: FineTuningSessionClientConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def _forward_backward_initial( + self, + session_id: str, + body: Union[_models.ForwardBackwardRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> Iterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_training_forward_backward_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + 
headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = True + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def begin_forward_backward( + self, + session_id: str, + body: _models.ForwardBackwardRequest, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Forward and backward pass. + + Submit a mini-batch for a combined forward + backward pass. + + Gradients accumulate until an optimizer step is issued. + Poll the returned operation URL for ``ForwardBackwardResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. 
+ :type body: ~azure.ai.finetuning_sessions.models.ForwardBackwardRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_forward_backward( + self, + session_id: str, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Forward and backward pass. + + Submit a mini-batch for a combined forward + backward pass. + + Gradients accumulate until an optimizer step is issued. + Poll the returned operation URL for ``ForwardBackwardResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. 
Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_forward_backward( + self, + session_id: str, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Forward and backward pass. + + Submit a mini-batch for a combined forward + backward pass. + + Gradients accumulate until an optimizer step is issued. + Poll the returned operation URL for ``ForwardBackwardResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def begin_forward_backward( + self, + session_id: str, + body: Union[_models.ForwardBackwardRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Forward and backward pass. + + Submit a mini-batch for a combined forward + backward pass. + + Gradients accumulate until an optimizer step is issued. + Poll the returned operation URL for ``ForwardBackwardResult``. + + :param session_id: Required. + :type session_id: str + :param body: Is one of the following types: ForwardBackwardRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.ForwardBackwardRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, PollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = self._forward_backward_initial( + session_id=session_id, + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: PollingMethod = cast( + PollingMethod, LROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs) + ) + elif polling is False: + polling_method = cast(PollingMethod, NoPolling()) + else: + polling_method = polling + if cont_token: + return 
LROPoller[_models.OperationResult].from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return LROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def _optim_step_initial( + self, + session_id: str, + body: Union[_models.OptimStepRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> Iterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_training_optim_step_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = True + pipeline_response: PipelineResponse = 
self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def begin_optim_step( + self, + session_id: str, + body: _models.OptimStepRequest, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Optimizer step. + + Apply accumulated gradients to the LoRA weights using the Adam optimizer. + + Poll the returned operation URL for ``OptimStepResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: ~azure.ai.finetuning_sessions.models.OptimStepRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. 
+ Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_optim_step( + self, + session_id: str, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Optimizer step. + + Apply accumulated gradients to the LoRA weights using the Adam optimizer. + + Poll the returned operation URL for ``OptimStepResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_optim_step( + self, + session_id: str, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Optimizer step. + + Apply accumulated gradients to the LoRA weights using the Adam optimizer. + + Poll the returned operation URL for ``OptimStepResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def begin_optim_step( + self, + session_id: str, + body: Union[_models.OptimStepRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Optimizer step. + + Apply accumulated gradients to the LoRA weights using the Adam optimizer. + + Poll the returned operation URL for ``OptimStepResult``. + + :param session_id: Required. + :type session_id: str + :param body: Is one of the following types: OptimStepRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.OptimStepRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, PollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = self._optim_step_initial( + session_id=session_id, + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: PollingMethod = cast( + PollingMethod, LROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs) + ) + elif polling is False: + polling_method = cast(PollingMethod, NoPolling()) + else: + polling_method = polling + if cont_token: + return LROPoller[_models.OperationResult].from_continuation_token( + 
polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return LROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + +class CheckpointsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~azure.ai.finetuning_sessions.FineTuningSessionClient`'s + :attr:`checkpoints` attribute. + """ + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: FineTuningSessionClientConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def _save_initial( + self, + session_id: str, + body: Union[_models.SaveCheckpointRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> Iterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + 
_content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_checkpoints_save_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = True + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def begin_save( + self, + session_id: str, + body: _models.SaveCheckpointRequest, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Save a training checkpoint. + + Save a training checkpoint (LoRA weights + optimizer state) to blob storage. 
+ + Poll the returned operation URL for ``SaveCheckpointResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: ~azure.ai.finetuning_sessions.models.SaveCheckpointRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_save( + self, + session_id: str, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Save a training checkpoint. + + Save a training checkpoint (LoRA weights + optimizer state) to blob storage. + + Poll the returned operation URL for ``SaveCheckpointResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. 
+ :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_save( + self, + session_id: str, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Save a training checkpoint. + + Save a training checkpoint (LoRA weights + optimizer state) to blob storage. + + Poll the returned operation URL for ``SaveCheckpointResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def begin_save( + self, + session_id: str, + body: Union[_models.SaveCheckpointRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Save a training checkpoint. + + Save a training checkpoint (LoRA weights + optimizer state) to blob storage. + + Poll the returned operation URL for ``SaveCheckpointResult``. + + :param session_id: Required. + :type session_id: str + :param body: Is one of the following types: SaveCheckpointRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.SaveCheckpointRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, PollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = self._save_initial( + session_id=session_id, + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: PollingMethod = cast( + PollingMethod, LROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs) + ) + elif polling is False: + polling_method = cast(PollingMethod, NoPolling()) + else: + polling_method = polling + if cont_token: + return LROPoller[_models.OperationResult].from_continuation_token( + 
polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return LROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def _save_sampler_weights_initial( + self, + session_id: str, + body: Union[_models.SaveSamplerWeightsRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> Iterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_checkpoints_save_sampler_weights_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = True + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: 
disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def begin_save_sampler_weights( + self, + session_id: str, + body: _models.SaveSamplerWeightsRequest, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Save sampler weights. + + Save sampler-compatible weights (no optimizer state) for generation. + + Poll the returned operation URL for ``SaveSamplerWeightsResult``. + The result contains ``sampling_session_id`` — pass this to subsequent sample calls. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: ~azure.ai.finetuning_sessions.models.SaveSamplerWeightsRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. 
+ :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_save_sampler_weights( + self, + session_id: str, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Save sampler weights. + + Save sampler-compatible weights (no optimizer state) for generation. + + Poll the returned operation URL for ``SaveSamplerWeightsResult``. + The result contains ``sampling_session_id`` — pass this to subsequent sample calls. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_save_sampler_weights( + self, + session_id: str, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Save sampler weights. + + Save sampler-compatible weights (no optimizer state) for generation. + + Poll the returned operation URL for ``SaveSamplerWeightsResult``. + The result contains ``sampling_session_id`` — pass this to subsequent sample calls. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def begin_save_sampler_weights( + self, + session_id: str, + body: Union[_models.SaveSamplerWeightsRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Save sampler weights. + + Save sampler-compatible weights (no optimizer state) for generation. + + Poll the returned operation URL for ``SaveSamplerWeightsResult``. + The result contains ``sampling_session_id`` — pass this to subsequent sample calls. + + :param session_id: Required. + :type session_id: str + :param body: Is one of the following types: SaveSamplerWeightsRequest, JSON, IO[bytes] + Required. + :type body: ~azure.ai.finetuning_sessions.models.SaveSamplerWeightsRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, PollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = self._save_sampler_weights_initial( + session_id=session_id, + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: PollingMethod = cast( + PollingMethod, LROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs) + ) + elif polling is False: + polling_method = cast(PollingMethod, NoPolling()) + else: + polling_method = polling + if cont_token: + return 
LROPoller[_models.OperationResult].from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return LROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def list( + self, + session_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> _models.CheckpointList: + """List checkpoints. + + List all checkpoints (training and sampler) for this session. + + :param session_id: Required. + :type session_id: str + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: CheckpointList. 
The CheckpointList is compatible with MutableMapping + :rtype: ~azure.ai.finetuning_sessions.models.CheckpointList + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.CheckpointList] = kwargs.pop("cls", None) + + _request = build_checkpoints_list_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + if _stream: + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + else: + deserialized = _deserialize(_models.CheckpointList, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + @api_version_validation( + 
params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def get( + self, + session_id: str, + checkpoint_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> _models.CheckpointInfo: + """Get checkpoint info. + + Get metadata for a specific checkpoint. + + :param session_id: Required. + :type session_id: str + :param checkpoint_id: Required. + :type checkpoint_id: str + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: CheckpointInfo. The CheckpointInfo is compatible with MutableMapping + :rtype: ~azure.ai.finetuning_sessions.models.CheckpointInfo + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.CheckpointInfo] = kwargs.pop("cls", None) + + _request = build_checkpoints_get_request( + session_id=session_id, + checkpoint_id=checkpoint_id, + foundry_features=foundry_features, + api_version=api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = kwargs.pop("stream", False) + pipeline_response: 
PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + if _stream: + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + else: + deserialized = _deserialize(_models.CheckpointInfo, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class SamplingOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~azure.ai.finetuning_sessions.FineTuningSessionClient`'s + :attr:`sampling` attribute. 
+ """ + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: FineTuningSessionClientConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def _sample_initial( + self, + session_id: str, + body: Union[_models.SampleRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> Iterator[bytes]: + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore + + _request = build_sampling_sample_request( + session_id=session_id, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, 
**path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = True + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + response_headers = {} + response_headers["Operation-Location"] = self._deserialize("str", response.headers.get("Operation-Location")) + + deserialized = response.iter_bytes() if _decompress else response.iter_raw() + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def begin_sample( + self, + session_id: str, + body: _models.SampleRequest, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Generate completions. + + Generate one or more completions using the session's current LoRA weights. + + Requires a prior ``checkpoint_sample`` call to push weights to the sampler. + Pass ``sampling_session_id`` and ``seq_id`` from that call in the request body. + + Poll the returned operation URL for ``FineTuningSampleResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: ~azure.ai.finetuning_sessions.models.SampleRequest + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. 
+ :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_sample( + self, + session_id: str, + body: JSON, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Generate completions. + + Generate one or more completions using the session's current LoRA weights. + + Requires a prior ``checkpoint_sample`` call to push weights to the sampler. + Pass ``sampling_session_id`` and ``seq_id`` from that call in the request body. + + Poll the returned operation URL for ``FineTuningSampleResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: JSON + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def begin_sample( + self, + session_id: str, + body: IO[bytes], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + content_type: str = "application/json", + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Generate completions. + + Generate one or more completions using the session's current LoRA weights. + + Requires a prior ``checkpoint_sample`` call to push weights to the sampler. + Pass ``sampling_session_id`` and ``seq_id`` from that call in the request body. + + Poll the returned operation URL for ``FineTuningSampleResult``. + + :param session_id: Required. + :type session_id: str + :param body: Required. + :type body: IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def begin_sample( + self, + session_id: str, + body: Union[_models.SampleRequest, JSON, IO[bytes]], + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> LROPoller[_models.OperationResult]: + """Generate completions. + + Generate one or more completions using the session's current LoRA weights. + + Requires a prior ``checkpoint_sample`` call to push weights to the sampler. + Pass ``sampling_session_id`` and ``seq_id`` from that call in the request body. + + Poll the returned operation URL for ``FineTuningSampleResult``. + + :param session_id: Required. + :type session_id: str + :param body: Is one of the following types: SampleRequest, JSON, IO[bytes] Required. + :type body: ~azure.ai.finetuning_sessions.models.SampleRequest or JSON or IO[bytes] + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: An instance of LROPoller that returns OperationResult. 
The OperationResult is + compatible with MutableMapping + :rtype: ~azure.core.polling.LROPoller[~azure.ai.finetuning_sessions.models.OperationResult] + :raises ~azure.core.exceptions.HttpResponseError: + """ + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + polling: Union[bool, PollingMethod] = kwargs.pop("polling", True) + lro_delay = kwargs.pop("polling_interval", self._config.polling_interval) + cont_token: Optional[str] = kwargs.pop("continuation_token", None) + if cont_token is None: + raw_result = self._sample_initial( + session_id=session_id, + body=body, + foundry_features=foundry_features, + api_version=api_version, + content_type=content_type, + cls=lambda x, y, z: x, + headers=_headers, + params=_params, + **kwargs + ) + raw_result.http_response.read() # type: ignore + kwargs.pop("error_map", None) + + def get_long_running_output(pipeline_response): + response_headers = {} + response = pipeline_response.http_response + response_headers["Operation-Location"] = self._deserialize( + "str", response.headers.get("Operation-Location") + ) + + deserialized = _deserialize(_models.OperationResult, response.json().get("result", {})) + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + return deserialized + + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + + if polling is True: + polling_method: PollingMethod = cast( + PollingMethod, LROBasePolling(lro_delay, path_format_arguments=path_format_arguments, **kwargs) + ) + elif polling is False: + polling_method = cast(PollingMethod, NoPolling()) + else: + polling_method = polling + if cont_token: + return LROPoller[_models.OperationResult].from_continuation_token( + 
polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output, + ) + return LROPoller[_models.OperationResult]( + self._client, raw_result, get_long_running_output, polling_method # type: ignore + ) + + +class Operations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~azure.ai.finetuning_sessions.FineTuningSessionClient`'s + :attr:`operations` attribute. + """ + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: FineTuningSessionClientConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace + @api_version_validation( + params_added_on={"virtual-public-preview": ["foundry_features", "api_version"]}, + ) + def get( + self, + session_id: str, + operation_id: str, + *, + foundry_features: Literal[FoundryFeaturesOptInKeys.FINETUNING_SESSIONS_V1_PREVIEW], + api_version: str, + **kwargs: Any + ) -> _models.OperationResult: + """Poll operation status. + + Poll for the result of an async fine-tuning operation (Azure LRO). + + Returns ``status: "running"`` while in progress, ``"succeeded"`` on completion, + or ``"failed"`` on error. When succeeded, the body contains the typed result + (``ForwardBackwardOperationResult``, ``OptimStepOperationResult``, etc.) + discriminated by ``type``. + + :param session_id: Required. + :type session_id: str + :param operation_id: Required. + :type operation_id: str + :keyword foundry_features: A feature flag opt-in required when using preview operations or + modifying persisted preview resources. 
FINETUNING_SESSIONS_V1_PREVIEW. Required. + :paramtype foundry_features: str or + ~azure.ai.finetuning_sessions.models.FINETUNING_SESSIONS_V1_PREVIEW + :keyword api_version: The API version to use for this operation. Required. + :paramtype api_version: str + :return: OperationResult. The OperationResult is compatible with MutableMapping + :rtype: ~azure.ai.finetuning_sessions.models.OperationResult + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.OperationResult] = kwargs.pop("cls", None) + + _request = build_operations_get_request( + session_id=session_id, + operation_id=operation_id, + foundry_features=foundry_features, + api_version=api_version, + headers=_headers, + params=_params, + ) + path_format_arguments = { + "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + } + _request.url = self._client.format_url(_request.url, **path_format_arguments) + + _decompress = kwargs.pop("decompress", True) + _stream = kwargs.pop("stream", False) + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + if _stream: + try: + response.read() # Load the body in memory and close the socket + except (StreamConsumedError, StreamClosedError): + pass + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = _failsafe_deserialize( + _models.ApiErrorResponse, + response, + ) + raise HttpResponseError(response=response, model=error) + + if _stream: + deserialized = response.iter_bytes() 
if _decompress else response.iter_raw() + else: + deserialized = _deserialize(_models.OperationResult, response.json()) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/operations/_patch.py b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/operations/_patch.py new file mode 100644 index 000000000000..87676c65a8f0 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/operations/_patch.py @@ -0,0 +1,21 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------- +"""Customize generated code here. + +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" + + +__all__: list[str] = [] # Add all objects you want publicly available to users at this package level + + +def patch_sdk(): + """Do not remove from this file. + + `patch_sdk` is a last resort escape hatch that allows you to do customizations + you can't accomplish using the techniques described in + https://aka.ms/azsdk/python/dpcodegen/python/customize + """ diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/py.typed b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/py.typed new file mode 100644 index 000000000000..e5aff4f83af8 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/azure/ai/finetuning_sessions/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. 
\ No newline at end of file diff --git a/sdk/ai/azure-ai-finetuning-sessions/azure_ai_finetuning_sessions-1.0.0b1-py3-none-any.whl b/sdk/ai/azure-ai-finetuning-sessions/azure_ai_finetuning_sessions-1.0.0b1-py3-none-any.whl new file mode 100644 index 000000000000..7e77bd8109fb Binary files /dev/null and b/sdk/ai/azure-ai-finetuning-sessions/azure_ai_finetuning_sessions-1.0.0b1-py3-none-any.whl differ diff --git a/sdk/ai/azure-ai-finetuning-sessions/dev_requirements.txt b/sdk/ai/azure-ai-finetuning-sessions/dev_requirements.txt new file mode 100644 index 000000000000..ad0907b03b93 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/dev_requirements.txt @@ -0,0 +1,4 @@ +-e ../../../eng/tools/azure-sdk-tools +../../core/azure-core +../../identity/azure-identity +aiohttp \ No newline at end of file diff --git a/sdk/ai/azure-ai-finetuning-sessions/pyproject.toml b/sdk/ai/azure-ai-finetuning-sessions/pyproject.toml new file mode 100644 index 000000000000..bc92a0259c82 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/pyproject.toml @@ -0,0 +1,61 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# Code generated by Microsoft (R) Python Code Generator. +# Changes may cause incorrect behavior and will be lost if the code is regenerated. 
+# -------------------------------------------------------------------------- + +[build-system] +requires = ["setuptools>=77.0.3", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "azure-ai-finetuning-sessions" +authors = [ + { name = "Microsoft Corporation", email = "azpysdkhelp@microsoft.com" }, +] +description = "Microsoft Corporation Azure Finetuning Sessions Client Library for Python" +license = "MIT" +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +requires-python = ">=3.9" +keywords = ["azure", "azure sdk"] + +dependencies = [ + "isodate>=0.6.1", + "azure-core>=1.37.0", + "typing-extensions>=4.6.0", +] +dynamic = [ +"version", "readme" +] + +[project.urls] +repository = "https://github.com/Azure/azure-sdk-for-python" + +[tool.setuptools.dynamic] +version = {attr = "azure.ai.finetuning_sessions._version.VERSION"} +readme = {file = ["README.md", "CHANGELOG.md"], content-type = "text/markdown"} + +[tool.setuptools.packages.find] +exclude = [ + "tests*", + "generated_tests*", + "samples*", + "generated_samples*", + "doc*", + "azure", + "azure.ai", +] + +[tool.setuptools.package-data] +pytyped = ["py.typed"] diff --git a/sdk/ai/azure-ai-finetuning-sessions/test_smoke.py b/sdk/ai/azure-ai-finetuning-sessions/test_smoke.py new file mode 100644 index 000000000000..e1ba0b11cf71 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/test_smoke.py @@ -0,0 +1,155 @@ +""" +Offline smoke test for azure-ai-finetuning-sessions SDK. +Mirrors the hero code patterns from SPEC_FOUNDRY_AICLIENT.md. +Run: python test_smoke.py +No real endpoint or credentials required. 
+""" +import time +from azure.core.credentials import AccessToken +from azure.core.pipeline.transport import HttpTransport, HttpResponse as _TransportHttpResponse +from azure.ai.finetuning_sessions import FineTuningSessionClient, FineTuningSession +from azure.ai.finetuning_sessions.models import ( + CreateSessionRequest, + Datum, + ModelInput, + ModelInputChunk, + LossFnInputs, + TensorData, + AdamParams, + LoRAConfig, + SamplingParams, +) + + +class _FakeCredential: + def get_token(self, *scopes, **kwargs): + return AccessToken("fake_token", int(time.time()) + 3600) + + def close(self): + pass + + +class _FakeHttpResponse(_TransportHttpResponse): + """Returns 200 OK with smart bodies: POST→pending, GET→succeeded.""" + + def __init__(self, request): + super().__init__(request, None) + self.status_code = 200 + if getattr(request, 'method', 'POST') == "GET": + self._body = b'{"type": "forward_backward", "operation_id": "op1", "status": "succeeded"}' + else: + self._body = b'{"request_id": "op1", "session_id": "session_xxx", "status": "pending"}' + self.headers = {"content-type": "application/json"} + self.content_type = "application/json" + self.reason = "OK" + + def body(self): + return self._body + + def text(self, encoding="utf-8"): + return self._body.decode(encoding or "utf-8") + + def stream_download(self, pipeline, **kwargs): + yield self._body + + def iter_bytes(self, **kwargs): + yield self._body + + def iter_raw(self, **kwargs): + yield self._body + + def read(self): + return self._body + + def json(self): + import json + return json.loads(self._body) + + def close(self): + pass + + def raise_for_status(self): + pass + + +class _FakeTransport(HttpTransport): + """Intercepts all HTTP calls and returns a fake 200 OK response — no network needed.""" + + def send(self, request, **kwargs): + return _FakeHttpResponse(request) + + def open(self): pass + def close(self): pass + def __enter__(self): return self + def __exit__(self, *args): self.close() + + +# ── Setup 
───────────────────────────────────────────────────────────────────── +print("=== azure-ai-finetuning-sessions smoke test ===") +print("(Based on SPEC_FOUNDRY_AICLIENT.md hero code samples)\n") + +client = FineTuningSessionClient( + endpoint="https://fake", + credential=_FakeCredential(), + transport=_FakeTransport(), +) +session = FineTuningSession(client, session_id="session_xxx") +print(f"✓ FineTuningSession: session_id={session.session_id}") + +assert hasattr(client, "sessions"), "missing: client.sessions" +assert hasattr(client, "training"), "missing: client.training" +assert hasattr(client, "checkpoints"), "missing: client.checkpoints" +assert hasattr(client, "sampling"), "missing: client.sampling" +assert hasattr(client, "operations"), "missing: client.operations" +print("✓ Sub-clients: sessions, training, checkpoints, sampling, operations") + +# ── Build training data ──────────────────────────────────────────────────────── +prompt_ids = [1, 2, 3, 4] +target_ids = [5, 6, 7] +all_ids = prompt_ids + target_ids +weights = [0.0] * len(prompt_ids) + [1.0] * len(target_ids) + +batch = [ + Datum( + model_input=ModelInput(chunks=[ModelInputChunk(tokens=all_ids[:-1])]), + loss_fn_inputs=LossFnInputs( + target_tokens=TensorData(data=[float(t) for t in all_ids[1:]]), + weights=TensorData(data=weights[1:]), + ), + ) +] +print(f"✓ Batch of {len(batch)} Datum built\n") + +# ── Spec Scenario 1: SFT training loop ──────────────────────────────────────── +fb_op = session.forward_backward(batch, loss_fn="cross_entropy") +print(f"✓ fb_op = session.forward_backward(batch, loss_fn='cross_entropy') → {type(fb_op).__name__}") + +opt_op = session.optim_step(AdamParams(learning_rate=1e-4, beta1=0.9, beta2=0.95, eps=1e-12, weight_decay=0.0)) +print(f"✓ opt_op = session.optim_step(AdamParams(learning_rate=1e-4)) → {type(opt_op).__name__}") + +ckpt_op = session.save_weights("sft_piglatin_v1") +print(f"✓ ckpt_op = session.save_weights('sft_piglatin_v1') → {type(ckpt_op).__name__}") + 
+# ── Spec Scenario 2: RFT sampling ───────────────────────────────────────────── +sampler_op = session.save_weights_for_sampler(seq_id=0) +print(f"✓ sampler_op = session.save_weights_for_sampler(seq_id=0) → {type(sampler_op).__name__}") + +sample_op = session.sample( + prompt_tokens=prompt_ids, + sampling_params=SamplingParams(max_tokens=32, temperature=1.0, top_p=1.0, top_k=-1), + num_samples=4, + sampling_session_id="sampling_abc123", + seq_id=0, + prompt_logprobs=True, +) +print(f"✓ sample_op = session.sample(prompt_tokens, params, num_samples=4) → {type(sample_op).__name__}") + +# ── Session creation body ────────────────────────────────────────────────────── +session_body = CreateSessionRequest( + type="training", + base_model="Qwen/Qwen3-0.6B", + lora_config=LoRAConfig(rank=16), +) +print(f"✓ CreateSessionRequest: base_model={session_body.base_model}, lora_rank={session_body.lora_config.rank}") + +print("\n=== All checks passed ✓ ===") diff --git a/sdk/ai/azure-ai-finetuning-sessions/tests/conftest.py b/sdk/ai/azure-ai-finetuning-sessions/tests/conftest.py new file mode 100644 index 000000000000..9b307a8b6deb --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/tests/conftest.py @@ -0,0 +1,132 @@ +# --------------------------------------------------------------------------- +# Shared fixtures for azure-ai-finetuning-sessions unit tests. 
+# --------------------------------------------------------------------------- +import time + +import pytest +from azure.core.credentials import AccessToken +from azure.core.pipeline.transport import HttpTransport +from azure.core.pipeline.transport import HttpResponse as _TransportHttpResponse + +from azure.ai.finetuning_sessions import FineTuningSessionClient, FineTuningSession +from azure.ai.finetuning_sessions.models import ( + Datum, + ModelInput, + ModelInputChunk, + LossFnInputs, + TensorData, +) + + +# ── Fake credential ────────────────────────────────────────────────────────── + +class FakeCredential: + def get_token(self, *scopes, **kwargs): + return AccessToken("fake_token", int(time.time()) + 3600) + + def close(self): + pass + + +# ── Fake HTTP transport ─────────────────────────────────────────────────────── + +class FakeHttpResponse(_TransportHttpResponse): + """Returns 200 OK with smart request/result bodies for POST vs GET.""" + + def __init__(self, request, body: bytes = None, status_code: int = 200): + super().__init__(request, None) + self.status_code = status_code + if body is None: + if getattr(request, 'method', 'POST') == "GET": + body = b'{"type": "forward_backward", "operation_id": "req1", "status": "succeeded"}' + else: + body = b'{"request_id": "req1", "session_id": "session_test", "status": "pending"}' + self.headers = {"content-type": "application/json"} + self.content_type = "application/json" + self.reason = "OK" + self._body = body + + def body(self): + return self._body + + def text(self, encoding=None): + return self._body.decode(encoding or "utf-8") + + def stream_download(self, pipeline, **kwargs): + yield self._body + + def iter_bytes(self, **kwargs): + yield self._body + + def iter_raw(self, **kwargs): + yield self._body + + def read(self): + return self._body + + def json(self): + import json + return json.loads(self._body) + + def close(self): + pass + + def raise_for_status(self): + pass + + +class 
FakeTransport(HttpTransport): + """Captures all outgoing requests and returns configurable fake responses.""" + + def __init__(self, response_body: bytes = None, status_code: int = 200): + self.requests: list = [] + self._response_body = response_body + self._status_code = status_code + + def send(self, request, **kwargs): + self.requests.append(request) + return FakeHttpResponse(request, self._response_body, self._status_code) + + def open(self): pass + def close(self): pass + def __enter__(self): return self + def __exit__(self, *args): self.close() + + +# ── Pytest fixtures ─────────────────────────────────────────────────────────── + +@pytest.fixture +def transport(): + return FakeTransport() + + +@pytest.fixture +def client(transport): + return FineTuningSessionClient( + endpoint="https://fake", + credential=FakeCredential(), + transport=transport, + ) + + +@pytest.fixture +def session(client): + return FineTuningSession(client, session_id="session_test") + + +@pytest.fixture +def batch(): + """A minimal single-datum training batch.""" + prompt_ids = [1, 2, 3, 4] + target_ids = [5, 6, 7] + all_ids = prompt_ids + target_ids + weights = [0.0] * len(prompt_ids) + [1.0] * len(target_ids) + return [ + Datum( + model_input=ModelInput(chunks=[ModelInputChunk(tokens=all_ids[:-1])]), + loss_fn_inputs=LossFnInputs( + target_tokens=TensorData(data=[float(t) for t in all_ids[1:]]), + weights=TensorData(data=weights[1:]), + ), + ) + ] diff --git a/sdk/ai/azure-ai-finetuning-sessions/tests/test_finetuning_session.py b/sdk/ai/azure-ai-finetuning-sessions/tests/test_finetuning_session.py new file mode 100644 index 000000000000..41bf32862a89 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/tests/test_finetuning_session.py @@ -0,0 +1,177 @@ +# --------------------------------------------------------------------------- +# Unit tests for FineTuningSession (_patch.py convenience wrapper). 
+# --------------------------------------------------------------------------- +import json + +import pytest + +from azure.ai.finetuning_sessions import FineTuningSession +from azure.ai.finetuning_sessions.models import ( + AdamParams, + LoRAConfig, + CreateSessionRequest, + OperationResult, + SamplingParams, +) + + +class TestFineTuningSessionInstantiation: + + def test_session_id_stored(self, client): + s = FineTuningSession(client, session_id="abc123") + assert s.session_id == "abc123" + + def test_client_stored(self, client): + s = FineTuningSession(client, session_id="abc123") + assert s._client is client + + +class TestSubClients: + + def test_all_sub_clients_present(self, client): + for name in ("sessions", "training", "checkpoints", "sampling", "operations"): + assert hasattr(client, name), f"missing: client.{name}" + + +class TestForwardBackward: + + def test_returns_operation_result(self, session, batch): + result = session.forward_backward(batch, loss_fn="cross_entropy") + assert isinstance(result, OperationResult) + + def test_default_loss_fn_is_cross_entropy(self, session, batch, transport): + session.forward_backward(batch) + req = transport.requests[0] + body = json.loads(req.body) + assert body["forward_backward_input"]["loss_fn"] == "cross_entropy" + + def test_custom_loss_fn(self, session, batch, transport): + session.forward_backward(batch, loss_fn="dpo") + req = transport.requests[0] + body = json.loads(req.body) + assert body["forward_backward_input"]["loss_fn"] == "dpo" + + def test_request_targets_correct_session(self, session, batch, transport): + session.forward_backward(batch) + req = transport.requests[0] + assert "session_test" in req.url + + def test_preview_header_sent(self, session, batch, transport): + session.forward_backward(batch) + req = transport.requests[0] + assert req.headers.get("Foundry-Features") == "FineTuningSessions=V1Preview" + + +class TestOptimStep: + + def test_returns_operation_result(self, session): + params = 
AdamParams(learning_rate=1e-4, beta1=0.9, beta2=0.95, eps=1e-12, weight_decay=0.0) + result = session.optim_step(params) + assert isinstance(result, OperationResult) + + def test_request_contains_adam_params(self, session, transport): + params = AdamParams(learning_rate=2e-5, beta1=0.9, beta2=0.95, eps=1e-12, weight_decay=0.01) + session.optim_step(params) + req = transport.requests[0] + body = json.loads(req.body) + assert body["adam_params"]["learning_rate"] == pytest.approx(2e-5) + + def test_request_targets_correct_session(self, session, transport): + session.optim_step(AdamParams(learning_rate=1e-4, beta1=0.9, beta2=0.95, eps=1e-12, weight_decay=0.0)) + assert "session_test" in transport.requests[0].url + + +class TestSaveWeights: + + def test_returns_operation_result(self, session): + assert isinstance(session.save_weights("my_ckpt"), OperationResult) + + def test_request_contains_path(self, session, transport): + session.save_weights("sft_piglatin_v1") + body = json.loads(transport.requests[0].body) + assert body["path"] == "sft_piglatin_v1" + + +class TestSaveWeightsForSampler: + + def test_returns_operation_result(self, session): + assert isinstance(session.save_weights_for_sampler(seq_id=0), OperationResult) + + def test_request_contains_seq_id(self, session, transport): + session.save_weights_for_sampler(seq_id=7) + body = json.loads(transport.requests[0].body) + assert body["seq_id"] == 7 + + def test_optional_path(self, session, transport): + session.save_weights_for_sampler(seq_id=0, path="explicit_path") + body = json.loads(transport.requests[0].body) + assert body["path"] == "explicit_path" + + +class TestSample: + + def test_returns_operation_result(self, session): + result = session.sample( + prompt_tokens=[1, 2, 3], + sampling_params=SamplingParams(max_tokens=16, temperature=1.0, top_p=1.0, top_k=-1), + num_samples=2, + ) + assert isinstance(result, OperationResult) + + def test_request_contains_num_samples(self, session, transport): + 
session.sample( + prompt_tokens=[1, 2, 3], + sampling_params=SamplingParams(max_tokens=16, temperature=1.0, top_p=1.0, top_k=-1), + num_samples=4, + ) + body = json.loads(transport.requests[0].body) + assert body["num_samples"] == 4 + + def test_request_contains_prompt_tokens(self, session, transport): + session.sample( + prompt_tokens=[10, 20, 30], + sampling_params=SamplingParams(max_tokens=16, temperature=1.0, top_p=1.0, top_k=-1), + ) + body = json.loads(transport.requests[0].body) + tokens = body["prompt"]["chunks"][0]["tokens"] + assert tokens == [10, 20, 30] + + def test_prompt_logprobs_default_false(self, session, transport): + session.sample( + prompt_tokens=[1, 2], + sampling_params=SamplingParams(max_tokens=8, temperature=1.0, top_p=1.0, top_k=-1), + ) + body = json.loads(transport.requests[0].body) + assert body.get("promptLogprobs", False) is False + + +class TestHeartbeat: + + def test_request_targets_correct_session(self, session, transport): + # heartbeat is a synchronous POST — returns 200 with a HeartbeatResponse body + transport._status_code = 200 + transport._response_body = b'{"session_id": "session_test"}' + session.heartbeat() + assert "session_test" in transport.requests[0].url + + +class TestClose: + + def test_returns_none(self, session): + assert session.close() is None + + def test_request_targets_correct_session(self, session, transport): + session.close() + assert "session_test" in transport.requests[0].url + + +class TestCreateSessionRequest: + + def test_fields(self): + req = CreateSessionRequest( + type="training", + base_model="Qwen/Qwen3-0.6B", + lora_config=LoRAConfig(rank=16), + ) + assert req.base_model == "Qwen/Qwen3-0.6B" + assert req.lora_config.rank == 16 diff --git a/sdk/ai/azure-ai-finetuning-sessions/tsp-location.yaml b/sdk/ai/azure-ai-finetuning-sessions/tsp-location.yaml new file mode 100644 index 000000000000..810bfaea3a10 --- /dev/null +++ b/sdk/ai/azure-ai-finetuning-sessions/tsp-location.yaml @@ -0,0 +1,6 @@ 
+directory: specification/ai-foundry/data-plane/Foundry/src/sdk-python-azure-ai-finetuning-sessions +commit: e8a4289ad7a07c4887d5cb4a4a967e61dbaeb365 +repo: Azure/azure-rest-api-specs-pr +additionalDirectories: +- specification/ai-foundry/data-plane/Foundry/src/session-finetuning +- specification/ai-foundry/data-plane/Foundry/src/common