# Note for a later DAVIS dataloader: temporally consistent depth estimator: https://github.com/yu-li/TCMonoDepth
# Note on the idea of streaming data directly from YouTube instead of downloading it: https://gist.github.com/Mxhmovd/41e7690114e7ddad8bcd761a76272cc3
import matplotlib.pyplot as plt; 
import cv2
import os
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms

import sys

from glob import glob
import os
import gzip
import json
import numpy as np
from data import common

val_seqs=['106_12686_26118', '134_15449_31106', '157_17287_33549', '167_18192_34876', '185_19994_39389', '194_20893_40953', '194_20931_42673', '194_20957_44512', '216_22815_47624', '216_22855_49631', '216_22862_49638', '250_26738_53523', '250_26765_54265', '250_26776_55178', '268_28507_57781', '286_30167_59426', '286_30175_59434', '304_31873_60475', '304_31883_60483', '336_34794_62331', '340_35384_65209', '344_35916_66357', '349_36579_68955', '411_56017_108225', '411_56077_108567', '415_57123_110212', '415_57140_110207', '415_57155_110229', '415_57187_110492', '417_57591_110844', '417_57631_110999', '417_57728_111393', '417_57844_111777', '421_58399_112544', '421_58448_112675', '423_58957_114450', '427_59918_115773', '427_60021_116224', '429_60367_116963', '429_60424_117419', '431_60836_118051', '431_60892_118397', '431_60948_118761', '433_61365_119703', '433_61495_120185', '433_61517_120210', '435_61832_120902', '437_62476_122841', '437_62533_123474', '439_62865_124110', '439_62905_124433', '462_65394_128333', '464_65742_128863', '469_66154_130018', '469_66193_130547', '471_66571_130950', '473_66956_131589', '478_67998_133378', '481_68717_133663', '481_68788_134262', '483_69190_135212', '491_70268_136787', '491_70290_136818', '494_70850_138568', '505_72714_141560', '516_74071_144574', '522_74995_146028', '528_76522_148065', '536_77983_150920', '554_79561_156370', '564_81568_161173', '570_83088_163219', '572_83696_165431', '576_84834_166729', '576_84898_167564', '578_85433_168116', '590_88811_174914', '596_91118_180844', '605_94545_187421', '607_95465_190320', '613_98189_195690', '101_11763_21624', '118_13854_28178', '149_16577_31934', '159_17491_33328', '187_20178_35733', '187_20226_39049', '208_22015_45658', '218_23004_47612', '246_26303_51369', '252_27075_55161', '270_28813_57820', '338_34922_64179', '34_1403_4393', '34_1423_4310', '341_35498_65210', '346_36136_66671', '350_36838_69222', '353_37347_70262', '360_37807_71336', '360_37830_72061', '363_38325_72995', 
'363_38434_73498', '363_38697_74297', '372_40983_81668', '372_41166_81967', '373_41527_82918', '373_41604_83177', '373_41785_83427', '374_41919_83809', '374_42249_84573', '375_42473_85077', '375_42540_85340', '375_42759_85613', '376_42852_85907', '377_43398_86309', '377_43567_86786', '378_43870_87344', '378_44092_87860', '379_44364_88385', '379_44431_88484', '379_44685_89131', '38_1658_5017', '38_1690_5065', '38_1698_5083', '380_44886_89716', '380_44957_89961', '380_45107_90249', '385_45498_91067', '385_45638_91286', '385_45804_91622', '386_45904_91718', '386_46099_92033', '386_46187_92177', '387_46408_92562', '387_46442_92622', '387_46530_92758', '387_46678_93017', '387_46702_92994', '387_46761_93082', '387_46859_93209', '391_46937_93449', '391_47030_93636', '391_47206_93935', '391_47315_94110', '392_47492_94463', '392_47579_94598', '392_47677_94730', '392_47699_94759', '392_47747_94814', '392_47780_94862', '392_47817_94915', '392_47845_94954', '393_47891_95143', '393_48140_95425', '393_48209_95525', '393_48233_95665', '393_48279_95718', '393_48323_95691', '393_48340_95728', '394_48389_95796', '394_48432_95855', '394_48469_95891', '394_48504_95974', '394_48544_96020', '394_48675_96453', '394_48760_96429', '395_48887_96800', '395_49023_96983', '395_49152_97149', '395_49230_97254', '395_49330_97363', '40_1824_5449', '40_1890_5789', '44_2231_6723', '442_63416_125375', '444_63735_125799', '444_63827_126039', '457_64431_126956', '457_64639_127339', '46_2569_7511', '463_65462_127953', '463_65687_128734', '470_66369_130447', '472_66666_130811', '472_66711_131009', '480_68585_133454', '482_69101_134748', '484_69474_135705', '492_70563_136726', '495_71074_137995', '501_71941_139711', '504_72466_140637', '504_72542_140774', '508_73233_142600', '514_73873_144048', '523_75131_146050', '523_75289_146569', '523_75467_147310', '529_76802_148372', '534_77577_150665', '537_78208_151858', '540_78895_152722', '540_79201_153392', '555_79813_155010', '555_79857_155192', 
'558_80582_158316', '563_81281_160760', '569_82612_163009', '571_83194_164346', '571_83374_165261', '575_84333_166457', '575_84513_166963', '575_84639_167249', '577_85044_168324', '579_85863_169821', '581_86292_171292', '586_87436_173292', '589_88438_176557', '595_90395_180050', '598_91836_182441', '601_92903_186232', '604_94242_189257', '606_94755_190139', '610_96535_193187', '612_97430_195286', '612_97818_196786', '614_98406_197618', '616_99363_199427', '620_101025_202361', '620_101542_203806', '429_60423_117417', '110_13072_25709', '12_92_460', '138_15853_30823', '161_17676_33269', '171_18574_33930', '171_18655_34299', '189_20388_37651', '198_21279_40528', '198_21304_43212', '210_22197_44976', '210_22214_46391', '220_23201_48799', '239_25238_50942', '254_27355_54486', '272_29062_56598', '272_29077_57710', '290_30767_58613', '308_32475_60385', '333_33901_61867', '339_35151_63581', '339_35245_64147', '341_35530_65361', '342_35828_65628', '350_36616_67878', '350_36691_68250', '353_37215_69638', '353_37459_70512', '360_37802_71331', '360_37829_72153', '360_37856_72450', '363_38335_72833', '363_38506_73861', '39_1748_5157', '39_1760_5220', '39_1770_5180', '39_1778_5192', '39_1801_5229', '39_1812_5251', '41_1928_5845', '41_1944_5742', '41_1957_5880', '41_1975_6051', '43_2132_6344', '43_2152_6520', '43_2162_6539', '43_2171_6635', '43_2186_6567', '43_2210_6597', '442_63337_125204', '45_2362_6728', '45_2378_6819', '45_2401_6970', '455_64243_126554', '457_64504_127090', '463_65625_128625', '465_65979_129472', '47_2681_7386', '470_66364_130439', '472_66750_131145', '477_67657_132286', '482_68990_134171', '492_70639_136955', '495_71221_138640', '501_71942_139718', '506_72878_141668', '508_73217_142532', '520_74658_145293', '523_75276_146542', '523_75505_147459', '529_76941_148926', '537_78125_151144', '540_78981_153019', '555_80068_156081', '558_80651_158549', '563_81449_161728', '571_83156_164166', '573_83799_166138', '577_84985_168004', '581_86245_171089', 
'581_86525_171977', '586_87444_173366', '589_88289_175266', '589_88466_176662', '601_92673_184931', '601_92889_186162', '604_94155_188780', '604_94318_189617', '606_95001_191248', '608_95810_192878', '610_96820_194015', '612_97798_196738', '616_99372_199396', '618_100343_200816', '620_101125_202675', '620_101402_203411', '113_13354_23676', '113_13369_24788', '123_14351_27841', '141_16151_31077', '164_17976_32964', '174_18886_34381', '192_20692_38357', '20_708_1559', '201_21576_40168', '201_21611_43651', '201_21619_43660', '201_21630_44257', '213_22506_45974', '213_22514_45987', '213_22523_46002', '213_22531_46012', '213_22540_46084', '223_23497_47185', '223_23509_48253', '223_23523_48362', '223_23533_48732', '223_23548_48775', '223_23556_48787', '223_23564_48877', '232_24393_48734', '242_25697_51075', '257_27800_53520', '257_27835_55483', '275_29524_57236', '28_957_2053', '30_1092_3391', '30_1224_3607', '31_1322_4034', '311_32921_60033', '311_32934_61313', '339_35133_63314', '341_35581_65545', '350_36861_69262', '353_37276_70049', '353_37547_70672', '372_41041_81726', '372_41285_82241', '373_41609_83065', '373_41822_83472', '374_42083_84239', '374_42224_84429', '375_42494_85111', '375_42649_85428', '375_42763_85617', '377_43471_86540', '377_43608_86849', '377_43823_87193', '378_44141_87986', '378_44313_88274', '379_44556_88894', '380_44953_89944', '380_45149_90317', '385_45446_90913', '385_45703_91392', '386_46267_92282', '387_46471_92682', '387_46624_92891', '391_46868_93637', '391_47093_93758', '392_47693_94753', '392_47844_94949', '393_47996_95216', '393_48157_95662', '393_48326_95696', '394_48441_95932', '394_48659_96205', '394_48801_96517', '395_48924_96844', '395_49196_97216', '397_50211_98675', '455_64253_126582', '463_65568_128301', '465_65965_129413', '472_66704_130962', '480_68495_133046', '482_69045_134527', '495_71072_137993', '501_72098_140327', '504_72678_141360', '514_73757_143538', '523_75335_146711', '529_76713_148234', '534_77408_149774', 
'540_78928_152877', '555_79750_154825', '558_80568_158237', '563_81472_161812', '571_83280_164655', '575_84456_166831', '579_85674_169243', '581_86592_172191', '586_87552_174017', '589_88480_176980', '595_90531_180549', '598_92185_184635', '601_92966_186762', '606_95132_191667', '612_97498_195590', '612_97794_196689', '106_12699_27006', '157_17274_33069', '167_18179_34134', '185_19995_39416', '194_20904_41101', '194_20918_42207', '194_20933_42347', '206_21811_45891', '216_22804_46891', '216_22853_49629', '244_25996_51881', '250_26754_54241', '250_26766_54266', '250_26784_55267', '250_26792_55276', '250_26799_55283', '268_28461_57272', '336_34824_63475', '349_36525_68608', '408_55370_106423', '408_55506_106872', '411_55954_107664', '411_56021_108239', '411_56071_108565', '415_57038_109649', '415_57120_110106', '417_57590_110843', '417_57725_111332', '417_57851_111792', '421_58397_112541', '421_58542_113114', '423_58857_113660', '427_59903_115605', '427_60019_116217', '429_60370_116966', '429_60520_117795', '431_60888_118393', '431_60960_118820', '433_61351_119571', '433_61412_120038', '433_61423_120054', '433_61588_120584', '435_61940_121682', '437_62337_122307', '437_62351_122283', '437_62364_122306', '437_62373_122317', '437_62387_122529', '437_62477_122826', '439_62832_124009', '439_62892_124345', '462_65349_127869', '464_65779_129397', '469_66182_130573', '473_66942_131534', '478_67945_133213', '478_67976_133276', '478_68019_133464', '481_68799_134296', '491_70256_136617', '494_70799_137722', '503_72336_140773', '505_72756_142249', '519_74505_144810', '519_74539_145211', '522_75017_146272', '528_76514_147836', '528_76564_149423', '539_78742_152913', '554_79529_154895', '554_79585_157148', '557_80309_158943', '567_82283_162335', '570_83091_163316', '570_83110_163593', '572_83643_164555', '574_84236_166119', '576_84878_167067', '580_86019_168902', '580_86074_169660', '582_86650_171058', '587_87896_173883', '587_87908_173978', '590_88833_175222', 
'590_88901_176554', '590_88968_178574', '593_89975_179761', '596_91125_180863', '599_92255_182422', '599_92376_184674', '605_94647_188714', '609_96429_192229', '423_58817_113362', '164_17975_32694', '201_21632_44332', '284_29970_57895', '366_39259_75842', '373_41530_82922', '377_43434_86430', '379_44451_88569', '379_44800_89451', '380_45230_90458', '386_46224_92218', '387_46497_92714', '391_46877_93332', '392_47622_94711', '393_48106_95357', '395_48888_96683', '396_49693_97888', '398_50633_99375', '399_51087_100214', '401_51898_101697', '402_52379_102561', '402_52398_102649', '402_52415_102616', '402_52643_103123', '403_52905_103320', '403_53071_103651', '404_53621_104608', '405_53872_105085', '407_54875_106103', '407_54922_106016', '407_54948_106083', '407_55086_106813', '407_55136_106925', '409_55528_107194', '414_56823_109421', '416_57375_110589', '42_2008_6247', '420_58317_112880', '424_59118_114494', '426_59760_116314', '428_60231_117549', '430_60689_118622', '44_2239_6852', '482_69175_134999', '506_72993_141867', '51_3016_8424', '51_3024_8440', '529_76795_148361', '540_79285_153703', '577_84975_167971', '592_89278_179154', '604_94190_189035', '220_23219_48924', '350_36741_68634', '366_39254_75637', '366_39353_76420', '366_39389_76769', '368_39874_78406', '369_40211_78807', '369_40265_78962', '370_40563_80268', '373_41757_83361', '374_42273_84516', '377_43512_86660', '378_43994_87788', '378_44052_87781', '380_45048_90128', '385_45635_91275', '391_47032_93657', '398_50653_99411', '399_51213_100560', '404_53460_104236', '414_56915_110208', '424_59144_114826', '426_59795_116456', '434_61686_121386', '442_63479_125458', '492_70537_136719', '575_84700_167514', '102_11975_21793', '119_13949_27606', '150_16701_32204', '160_17625_33698', '160_17632_33568', '188_20295_35748', '188_20354_39011', '188_20368_39435', '197_21184_40996', '197_21203_41598', '197_21210_41881', '197_21223_42168', '219_23117_47831', '228_23994_49755', '253_27199_53810', '253_27238_55312', 
'271_28932_56626', '271_28941_56656', '271_28957_57359', '271_28974_57650', '271_28984_57800', '289_30621_58908', '307_32330_60383', '307_32339_60417', '307_32346_60426', '339_35047_62952', '339_35211_63984', '34_1447_4494', '34_1492_4822', '341_35592_65623', '346_36081_66403', '351_37073_67649', '354_37668_70185', '372_41038_81732', '372_41253_82181', '373_41474_82835', '373_41673_83178', '373_41798_83514', '374_41950_83899', '374_41993_84073', '374_42116_84279', '374_42159_84431', '374_42201_84378', '374_42259_84589', '374_42328_84601', '375_42463_85070', '375_42572_85395', '375_42755_85609', '376_42883_85952', '377_43479_86551', '377_43620_86869', '377_43676_87025', '377_43762_87114', '378_43968_87632', '378_44065_87805', '378_44194_88148', '379_44407_88531', '379_44523_88747', '379_44694_89085', '38_1681_5047', '38_1727_5128', '385_45499_91077', '40_1844_5618', '40_1866_5696', '40_1874_5707', '40_1896_5809', '402_52633_103103', '402_52656_103125', '42_2036_6308', '42_2061_6354', '42_2092_6411', '42_2118_6508', '44_2228_6701', '442_63286_125078', '442_63480_125459', '444_63734_125797', '455_64217_126513', '455_64319_126678', '457_64458_127021', '46_2488_7393', '46_2592_7542', '463_65464_127940', '465_65869_129003', '470_66260_129942', '470_66378_130462', '472_66660_130786', '475_67182_131863', '477_67648_132231', '480_68470_132929', '482_68924_133884', '484_69382_135210', '492_70565_136731', '492_70612_136879', '492_70647_137042', '495_71013_137787', '495_71246_138655', '498_71573_139192', '501_72071_140275', '504_72471_140673', '506_72858_141575', '508_73237_142609', '514_73654_142932', '514_73724_143495', '520_74623_145020', '520_74636_145117', '523_75291_146574', '523_75493_147407', '529_76740_148279', '529_76931_148809', '529_77005_149385', '534_77398_149722', '534_77579_150676', '537_78219_152016', '540_79183_153341', '555_79670_154671', '555_79936_155415', '558_80453_157693', '558_80497_157844', '558_80564_158231', '563_81329_160869', '563_81474_161817', 
'569_82585_162917', '571_83203_164356', '571_83408_165365', '575_84417_166725', '575_84707_167597', '575_84718_167600', '577_85194_168746', '579_85849_169754', '581_86212_170747', '586_87359_172876', '589_88400_176133', '595_90522_180526', '598_91719_182282', '598_92167_184497', '601_92682_184957', '601_92804_185569', '604_93940_187668', '604_94139_188590', '606_94834_190512', '606_95114_191619', '610_96631_193649', '610_96917_194306', '610_97041_194639', '610_97065_194728', '610_97075_194745', '610_97082_194754', '610_97155_194927', '612_97524_195818', '612_97692_196260', '612_97835_196825', '616_99454_199681', '109_12985_24804', '150_16686_31921', '188_20349_38443', '219_23118_48730', '247_26472_51818', '253_27232_54734', '28_932_2976', '31_1266_3780', '339_35216_64044', '351_37001_67542', '372_41020_81814', '372_41268_82206', '373_41573_83003', '373_41817_83467', '375_42373_84845', '377_43419_86667', '377_43817_87186', '378_44318_88279', '380_44997_89975', '385_45383_90817', '386_46288_92363', '387_46785_93116', '392_47387_94384', '393_48098_95407', '396_49384_97438', '396_49684_97872', '397_50249_98714', '399_50983_100007', '399_51088_100216', '399_51242_100634', '401_52072_102142', '401_52219_102322', '402_52660_103133', '404_53466_104260', '404_53658_104654', '404_53792_104918', '407_54885_105726', '407_55069_106724', '409_55534_107183', '410_55781_108227', '414_56817_109442', '414_56902_110062', '416_57417_111018', '420_58288_112769', '420_58358_113182', '422_58754_114424', '428_60145_116853', '430_60647_118261', '430_60682_118523', '430_60744_119178', '432_61242_120328', '434_61678_121347', '436_62155_122349', '436_62295_123577', '444_63680_125652', '463_65590_128361', '48_2699_7889', '48_2718_7933', '48_2727_7938', '480_68481_133013', '495_70993_137702', '50_2904_8607', '50_2911_8619', '50_2920_8627', '50_2928_8645', '50_2946_8867', '50_2958_8877', '50_2970_9015', '50_2977_9019', '50_2986_9056', '50_2995_9134', '504_72648_141222', '520_74650_145188', 
'534_77455_149990', '540_78956_152968', '555_79708_154767', '563_81168_159798', '571_83281_164669', '577_85080_168386', '598_92039_183502', '606_94846_190590', '610_96781_193965', '612_97773_196583', '616_99287_198871', '618_100344_200850', '108_12867_22800', '118_13872_29491', '159_17497_33508', '169_18397_34794', '196_21101_41593', '196_21117_42823', '208_22006_45546', '218_23001_47455', '237_24940_51353', '252_27042_53471', '270_28762_56391', '270_28784_57220', '288_30476_59554', '339_35059_63052', '341_35566_65497', '351_37047_67839', '353_37373_70321', '354_37610_70059', '372_40889_81028', '372_41070_81770', '372_41198_82135', '372_41322_82305', '373_41501_82878', '373_41671_83205', '374_41862_83720', '374_42064_84229', '374_42322_84591', '375_42471_85182', '375_42652_85433', '375_42808_85676', '377_43492_86700', '377_43644_86912', '377_43772_87134', '378_43975_87639', '379_44584_88939', '379_44820_89503', '380_44910_90063', '380_45074_90196', '380_45156_90335', '385_45750_91462', '386_45994_91883', '386_46163_92141', '386_46244_92256', '387_46642_92916', '391_46973_93666', '391_47245_93974', '392_47371_94303', '392_47539_94545', '393_47888_95038', '394_48564_96090', '395_49096_97096', '396_49447_97596', '396_49613_97768', '396_49689_97877', '397_49951_98349', '397_50120_98557', '397_50201_98664', '397_50299_98775', '398_50398_98964', '398_50562_99253', '398_50646_99429', '398_50665_99436', '398_50701_99505', '398_50733_99548', '398_50765_99605', '398_50813_99730', '399_51048_100100', '400_51401_100980', '401_51933_101782', '401_52176_102265', '401_52192_102282', '401_52322_102442', '402_52400_102577', '402_52486_102870', '403_52901_103244', '404_53526_104366', '405_53898_105114', '405_54019_105370', '405_54162_105559', '442_63444_125406', '463_65651_128614', '470_66258_129935', '472_66820_131412', '480_68592_133469', '484_69390_135216', '492_70688_137461', '495_71085_138106', '501_71937_139766', '504_72515_140724', '506_72995_141936', '517_74134_144494', 
'523_75363_146965', '529_77012_149393', '540_78951_152958', '555_79682_154727', '558_80571_158241', '563_81163_159786', '563_81482_161832', '571_83259_164593', '573_83723_165788', '575_84681_167426', '579_85800_169676', '581_86278_171264', '581_86537_171992', '586_87419_173247', '586_87430_173273', '589_88112_174229', '589_88135_174300', '589_88232_175069', '589_88593_177585', '589_88646_178210', '592_89385_179812', '598_92069_183719', '601_92762_185374', '604_93935_187661', '606_94887_190835', '608_95665_192087', '612_97578_195861', '612_97926_197184', '616_99293_199077', '618_100332_200748', '62_4316_10771', '62_4319_10728', '62_4324_11285', '62_4327_10713', '62_4345_10704']

from PIL import Image
def _load_16big_png_depth(depth_png) -> np.ndarray:
    """Decode a depth PNG whose 16-bit pixels hold raw float16 bit patterns.

    The image is stored with 16-bit depth, but PIL reads it as an "I" (32-bit)
    image. The pixel values are therefore narrowed to uint16, their bytes are
    reinterpreted as float16, and the result is widened to float32 and laid
    out as a (height, width) array.
    """
    with Image.open(depth_png) as img:
        width, height = img.size
        raw_bits = np.array(img, dtype=np.uint16)
        as_half = np.frombuffer(raw_bits, dtype=np.float16)
        depth = as_half.astype(np.float32).reshape((height, width))
    return depth
def _load_depth(path, scale_adjustment) -> np.ndarray:
    """Load a float16-encoded depth PNG, multiply it by ``scale_adjustment``,
    replace non-finite values (NaN/inf) with 0, and add a leading axis."""
    depth = _load_16big_png_depth(path) * scale_adjustment
    invalid = ~np.isfinite(depth)
    depth[invalid] = 0.0
    return depth[np.newaxis]  # fake feature channel

# Helper functions below (image globbing and camera geometry for lifting pixels/depth into 3D); they can be skipped when reading the dataset class.
def glob_imgs(path):
    """Return paths of the png/jpg/JPEG/JPG files directly inside ``path``.

    Results are grouped in the order of the extension patterns below. On
    platforms where glob matching is case-insensitive (e.g. Windows),
    ``*.jpg`` and ``*.JPG`` match the same files. Duplicates are therefore
    dropped, keeping the first occurrence, so each file is returned once.
    """
    imgs = []
    for ext in ["*.png", "*.jpg", "*.JPEG", "*.JPG"]:
        imgs.extend(glob(os.path.join(path, ext)))
    # dict preserves insertion order, so this de-duplicates without reordering.
    return list(dict.fromkeys(imgs))


def pick(list, item_idcs):
    """Return the elements of ``list`` at positions ``item_idcs``, in that order.

    A falsy ``list`` (empty or None) is returned unchanged. The parameter name
    shadows the builtin ``list``; it is kept for callers passing it by keyword.
    """
    if list:
        return [list[idx] for idx in item_idcs]
    return list


def parse_intrinsics(intrinsics):
    """Split intrinsics matrices into ``(fx, fy, cx, cy)``.

    The last two axes are read as a matrix laid out like
    ``[[fx, _, cx], [_, fy, cy], ...]``. Each returned value keeps a trailing
    singleton axis, i.e. it has shape ``intrinsics.shape[:-2] + (1,)``.
    """
    row_x = intrinsics[..., 0, :]
    row_y = intrinsics[..., 1, :]
    fx = row_x[..., 0:1]
    cx = row_x[..., 2:3]
    fy = row_y[..., 1:2]
    cy = row_y[..., 2:3]
    return fx, fy, cx, cy


from einops import rearrange, repeat
def ch_sec(x):
    """Rearrange ``(..., c, x, y)`` into ``(..., x * y, c)``.

    The two trailing spatial axes are flattened and channels are moved last.
    """
    return rearrange(x, "... c x y -> ... (x y) c")


def hom(x, i=-1):
    """Append a slice of ones along axis ``i`` (homogeneous coordinates).

    The appended slice has size 1 along ``i`` and the same dtype and device as ``x``.
    """
    first_slice = x.unbind(i)[0].unsqueeze(i)
    return torch.cat((x, torch.ones_like(first_slice)), i)


def expand_as(x, y):
    """Append trailing singleton axes to ``x`` until it has as many axes as ``y``.

    ``x`` is returned as-is when it already has at least as many axes as ``y``.
    """
    missing = len(y.shape) - len(x.shape)
    for _ in range(max(missing, 0)):
        x = x.unsqueeze(-1)
    return x


def lift(x, y, z, intrinsics, homogeneous=False):
    """Back-project pixel coordinates at depth ``z`` into camera coordinates.

    Args:
        x: Pixel x coordinates, e.g. shape (batch_size, num_points).
        y: Pixel y coordinates, same layout as ``x``.
        z: Depth per point. It is stacked with the lifted x/y, so it must
            end up with the same shape as them.
        intrinsics: Intrinsics matrices, unpacked by ``parse_intrinsics``.
        homogeneous: If True, append a fourth component of ones.

    Returns:
        Tensor of ``(x_cam, y_cam, z)`` (plus ones if ``homogeneous``)
        stacked along a new last axis.
    """
    fx, fy, cx, cy = parse_intrinsics(intrinsics)
    # The intrinsics carry a trailing singleton axis; pad them to x/y's rank.
    fx, cx = expand_as(fx, x), expand_as(cx, x)
    fy, cy = expand_as(fy, y), expand_as(cy, y)

    x_cam = (x - cx) / fx * z
    y_cam = (y - cy) / fy * z

    components = [x_cam, y_cam, z]
    if homogeneous:
        components.append(torch.ones_like(z).to(x.device))
    return torch.stack(components, dim=-1)


def world_from_xy_depth(xy, depth, cam2world, intrinsics):
    """Convert pixel coordinates plus depth into world-space points.

    Args:
        xy: Pixel coordinates; ``xy[..., 0]`` is x and ``xy[..., 1]`` is y.
        depth: Depth per pixel, passed to ``lift`` as ``z``.
        cam2world: Camera-to-world matrices with a leading batch axis.
        intrinsics: Camera intrinsics as accepted by ``lift``.

    Returns:
        The first three components of each homogeneous camera-space point
        after multiplication by ``cam2world``.
    """
    _batch, *_rest = cam2world.shape  # requires at least one (batch) axis

    points_cam = lift(
        xy[..., 0], xy[..., 1], depth, intrinsics=intrinsics, homogeneous=True
    )
    # Multiply every homogeneous point by its batch's cam2world matrix.
    points_world = torch.einsum("b...ij,b...kj->b...ki", cam2world, points_cam)
    return points_world[..., :3]


def get_ray_directions(xy, cam2world, intrinsics, normalize=True):
    """Return ray directions from each camera position through pixels ``xy``.

    Each pixel is lifted to world space at unit depth. The camera position is
    taken as the translation column ``cam2world[..., :3, 3]`` and subtracted.
    Results are unit-length when ``normalize`` is True.
    """
    unit_depth = torch.ones(xy.shape[:-1]).to(xy.device)
    targets = world_from_xy_depth(
        xy, unit_depth, intrinsics=intrinsics, cam2world=cam2world
    )  # (batch, num_samples, 3)

    origins = cam2world[..., :3, 3][..., None, :]
    rays = targets - origins  # (batch, num_samples, 3)
    return F.normalize(rays, dim=-1) if normalize else rays

from PIL import Image
# NOTE(review): this redefines an identical _load_16big_png_depth earlier in
# this module; being later, this definition is the one in effect. Consider
# keeping only one copy.
def _load_16big_png_depth(depth_png) -> np.ndarray:
    """Decode a depth PNG whose 16-bit pixels hold raw float16 bit patterns.

    The image is stored with 16-bit depth, but PIL reads it as an "I" (32-bit)
    image. The pixel values are therefore narrowed to uint16, their bytes are
    reinterpreted as float16, and the result is widened to float32 and laid
    out as a (height, width) array.
    """
    with Image.open(depth_png) as img:
        width, height = img.size
        raw_bits = np.array(img, dtype=np.uint16)
        as_half = np.frombuffer(raw_bits, dtype=np.float16)
        depth = as_half.astype(np.float32).reshape((height, width))
    return depth
# NOTE(review): this redefines an identical _load_depth earlier in this
# module; being later, this definition is the one in effect.
def _load_depth(path, scale_adjustment) -> np.ndarray:
    """Load a float16-encoded depth PNG, multiply it by ``scale_adjustment``,
    replace non-finite values (NaN/inf) with 0, and add a leading axis."""
    depth = _load_16big_png_depth(path) * scale_adjustment
    invalid = ~np.isfinite(depth)
    depth[invalid] = 0.0
    return depth[np.newaxis]  # fake feature channel

# NOTE: currently using CO3D v1 because v2 switched to NDC cameras. TODO: write conversion code (v2 uses different intrinsics), verify the resulting point clouds, and then switch to v2.

class Co3DNoCams(torch.utils.data.Dataset):
    """Dataset for a class of objects, where each datapoint is a SceneInstanceDataset."""

    def __init__(
        self,
        n_skip=1,
        num_trgt=1,
        res_factor=1,
        val=False,
        num_cat=1000,
        overfit=False,
        category=None,
        seq_query=None,
    ):
        if num_cat is None: num_cat=1000

        self.n_trgt=num_trgt
        self.of=overfit
        self.val=val
        self.num_skip=n_skip
        self.res_factor=res_factor

        self.base_path="/data/co3dhydrants/co3d/hydrants/"#os.environ['CO3D_ROOT']
        print(self.base_path)

        # Get sequences!
        from collections import defaultdict
        sequences = defaultdict(list)
        self.total_num_data=0
        self.all_frame_names=[]
        all_cats = [ "hydrant","teddybear","apple", "ball", "bench", "cake", "donut", "plant", "suitcase", "vase","backpack", "banana", "baseballbat", 
                     "baseballglove",  "bicycle", "book", "bottle", "bowl", "broccoli",  "car", "carrot", "cellphone", "chair", "couch", "cup",  
                     "frisbee", "hairdryer", "handbag", "hotdog", "keyboard", "kite", "laptop", "microwave", "motorcycle", "mouse", "orange", "parkingmeter", 
                     "pizza",  "remote", "sandwich", "skateboard", "stopsign",  "toaster", "toilet", "toybus", "toyplane", "toytrain", "toytruck", "tv", "umbrella",  "wineglass", ]

        seq_query = "106_12648_23157"
        for cat in (all_cats[:num_cat]) if category is None else [category]:

            print(cat)
            dataset = json.loads(gzip.GzipFile(os.path.join(self.base_path,cat,"frame_annotations.jgz"),"rb").read().decode("utf8"))
            val_amt = int(len(dataset)*.03)
            if seq_query is None: dataset = dataset[:-val_amt] if not val else dataset[-val_amt:]
            self.total_num_data+=len(dataset)
            for i,data in enumerate(dataset):
                if data["sequence_name"] in val_seqs and not val:continue
                if seq_query is not None and data["sequence_name"]!=seq_query:continue
                self.all_frame_names.append((data["sequence_name"],data["frame_number"]))
                sequences[data["sequence_name"]].append(data)
    
        if seq_query is not None: self.of=True
        sorted_seq={}
        for k,v in sequences.items():
            sorted_seq[k]=sorted(sequences[k],key=lambda x:x["frame_number"])
        if val:
            self.total_num_data_all=len(self.all_frame_names)
            self.total_all_frame_names=self.all_frame_names
            self.total_sorted_seq=sorted_seq
        self.seqs = sorted_seq
        print("done with dataloader init")

    def set_seq(self,seq_query):
        self.seqs={seq_query:self.total_sorted_seq[seq_query]}
        self.all_frame_names=[x for x in self.total_all_frame_names if x[0]==seq_query]
        self.total_num_data=len(self.all_frame_names)

    def sparsify(self, dict, sparsity):
        new_dict = {}
        if sparsity is None:
            return dict
        else:
            # Sample upper_limit pixel idcs at random.
            rand_idcs = np.random.choice(
                self.img_sidelength ** 2, size=sparsity, replace=False
            )
            for key in ["rgb", "uv"]:
                new_dict[key] = dict[key][rand_idcs]

            for key, v in dict.items():
                if key not in ["rgb", "uv"]:
                    new_dict[key] = dict[key]

            return new_dict

    def set_img_sidelength(self, new_img_sidelength):
        """For multi-resolution training: Updates the image sidelength with which images are loaded."""
        self.img_sidelength = new_img_sidelength
        for instance in self.all_instances:
            instance.set_img_sidelength(new_img_sidelength)

    def __len__(self):
        return ((self.total_num_data-(1+self.n_trgt)*(1+(max(self.num_skip) if type(self.num_skip)==list else self.num_skip)))
                if not self.of else 1)

    def __getitem__(self, idx,seq_query=None):
        #idx=1000

        context = []
        trgt = []
        post_input = []

        n_skip = (random.choice(self.num_skip) if type(self.num_skip)==list else self.num_skip) + 1

        if seq_query is None:
            try: 
                seq_name,frame_idx=self.all_frame_names[idx]
            except: 
                print(f"Out of bounds erorr at {idx}. Investigate.")
                return self[-2*n_skip*self.n_trgt if self.val else np.random.randint(len(self))]
        
        if seq_query is not None:
            frame_idx=idx
            seq_name = list(self.seqs.keys())[seq_query]
            all_frames= self.seqs[seq_name]
        else:
            all_frames=self.seqs[seq_name] if not self.of else self.seqs[random.choice(list(self.seqs.keys())[:int(self.of)])]

        if len(all_frames)<=self.n_trgt*n_skip or frame_idx >= (len(all_frames)-self.n_trgt*n_skip):
            frame_idx=0
            if len(all_frames)<=self.n_trgt*n_skip or frame_idx >= (len(all_frames)-self.n_trgt*n_skip):
                if len(all_frames)<=self.n_trgt*n_skip:
                    print(len(all_frames) ," frames < ",self.n_trgt*n_skip," queries")
                print("returning low/high")
                return self[-2*n_skip*self.n_trgt if self.val else np.random.randint(len(self))]
        start_idx = frame_idx 

        if self.of and 1: start_idx=0

        frames = all_frames[start_idx:start_idx+self.n_trgt*n_skip:n_skip]
        if np.random.rand()<.5 and not self.of and not self.val and idx!=1000: frames=frames[::-1]

        paths = [os.path.join(self.base_path,x["image"]["path"]) for x in frames]
        for path in paths:
            if not os.path.exists(path):
                print("path missing")
                return self[np.random.randint(len(self))]

        imgs=[torch.from_numpy(plt.imread(path)) for path in paths]

        Ks=[]
        c2ws=[]
        depths=[]
        for data in frames:

            # Below pose processing taken from co3d github issue
            p = data["viewpoint"]["principal_point"]
            f = data["viewpoint"]["focal_length"]
            h, w = data["image"]["size"]
            org_ratio=h/w
            K = np.eye(3)
            s = (min(h, w)) / 2
            K[0, 0] = f[0] * (w) / 2
            K[1, 1] = f[1] * (h) / 2
            K[0, 2] = -p[0] * s + (w) / 2
            K[1, 2] = -p[1] * s + (h) / 2

            # Normalize intrinsics to [-1,1]
            raw_K=[torch.from_numpy(K).clone(),[h,w]]
            K[:2] /= torch.tensor([w, h])[:, None]

            # not sure if this is correct
            K[0,0]*=h/w

            Ks.append(torch.from_numpy(K).float())

            R = np.asarray(data["viewpoint"]["R"]).T   # note the transpose here
            T = np.asarray(data["viewpoint"]["T"]) 
            pose = np.concatenate([R,T[:,None]],1)
            pose = torch.from_numpy( np.diag([-1,-1,1]).astype(np.float32) @ pose )# flip the direction of x,y axis
            tmp=torch.eye(4)
            tmp[:3,:4]=pose
            c2ws.append(tmp.inverse())

        Ks=torch.stack(Ks)
        c2w=torch.stack(c2ws).float()
        minx,miny=min([x.size(0) for x in imgs]),min([x.size(1) for x in imgs])

        imgs=[x[:minx,:miny].float() for x in imgs]

        s=1
        trgt={"rgb":(torch.stack(imgs).permute(0,3,1,2)/255)*2-1,"c2w":c2w,"intrinsics":Ks,"path":paths[-1],"org_ratio":org_ratio}
        #return common.make_sample(trgt,1/org_ratio,hires_factor=h,budget=192*640/(8//s),low_res=[160,96],hi_res=[1024//2,576//2])
        #return common.make_sample(trgt,1/org_ratio,hires_factor=h,budget=192*640/(8//s),low_res=[160,96],hi_res=[1024,576])
        #hi_res,low_res = [[640,1024],[160,256]] if self.of else [[320,192],[224,128]]
        hi_res,low_res = [[1024,640],[256,160]] if self.of else [[320,192],[224,128]]
        return common.make_sample(trgt,1/org_ratio,hires_factor=h,budget=192*640/(8//s),low_res=low_res,hi_res=hi_res)
        #return common.make_sample(trgt,1/org_ratio,hires_factor=h,budget=192*640/(8//s),low_res=[192,128],hi_res=[1024,576])
        return common.make_sample(trgt,102/156,med_factor=2)
        #low_res=torch.tensor([156,102])
        #return common.make_sample(trgt,(192,640),(int(3*64/self.res_factor),int(3*208/self.res_factor)),(int(320*1.0),int(1024*1.0)))
        #return common.make_sample(trgt,(low_res/self.res_factor).long().tolist(),(low_res*2).long().tolist(),(low_res*4).long().tolist())
